-
Notifications
You must be signed in to change notification settings - Fork 0
/
constrained_oasisAR1.m
250 lines (223 loc) · 7.7 KB
/
constrained_oasisAR1.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
function [c, s, b, g, lam, active_set] = constrained_oasisAR1(y, g, sn, optimize_b,...
    optimize_g, decimate, maxIter)
%% Infer the most likely discretized spike train underlying an AR(1) fluorescence trace
% Solves the sparse non-negative deconvolution problem
%   min 1/2|c-y|^2 + lam |s|_1  subject to  s_t = c_t - g c_{t-1} >= 0
% in its constrained form: starting from lam = 0, the sparsity penalty is
% adapted until the residual sum of squares |y - c - b|^2 reaches the
% noise-implied level thresh = sn^2 * T.
%% inputs:
% y: T*1 vector, one-dimensional array containing the fluorescence
%   intensities with one entry per time bin (reshaped to a column).
% g: scalar, parameter of the AR(1) process that models the fluorescence
%   impulse response. Estimated with estimate_time_constant if empty.
% sn: scalar, standard deviation of the noise distribution. Estimated with
%   GetSn if empty.
% optimize_b: bool, also optimize the baseline b if true (default false).
% optimize_g: used as a logical flag here; if nonzero, g is re-estimated
%   with update_g during the iterations (default 0).
% decimate: int, decimation factor for estimating hyper-parameters faster
%   on decimated data. NOT implemented yet: values > 1 are reset to 1.
% maxIter: int, maximum number of iterations (default 10).
%% outputs
% c: T*1 vector, the inferred denoised fluorescence signal at each time bin.
% s: T*1 vector, discretized deconvolved neural activity (spikes).
% b: scalar, fluorescence baseline (0 when optimize_b is false).
% g: scalar, parameter of the AR(1) process.
% lam: scalar, sparsity penalty parameter.
% active_set: npool x 4 matrix, active sets (one row per pool; column 3 is
%   the pool's first time index and column 4 its length).
%% Authors: Pengcheng Zhou, Carnegie Mellon University, 2016
% ported from the Python implementation from Johannes Friedrich
%% References
% Friedrich J et.al., NIPS 2016, Fast Active Set Method for Online Spike Inference from Calcium Imaging
%% input arguments
y = reshape(y, [], 1);
T = length(y);
if ~exist('g', 'var') || isempty(g)
    g = estimate_time_constant(y, 1);
end
if ~exist('sn', 'var') || isempty(sn)
    sn = GetSn(y);
end
if ~exist('optimize_b', 'var') || isempty(optimize_b)
    optimize_b = false;
end
if ~exist('optimize_g', 'var') || isempty(optimize_g)
    optimize_g = 0;
end
if ~exist('decimate', 'var') || isempty(decimate)
    decimate = 1;
else
    decimate = max(1, round(decimate));
end
if ~exist('maxIter', 'var') || isempty(maxIter)
    maxIter = 10;
end
thresh = sn * sn * T;   % target RSS implied by the noise level
lam = 0;                % lam is not an input: always start unpenalized
% Shared with the nested update_phi (it must appear in this workspace to be
% shared); update_phi sets it to false when no real shift dphi exists.
flag_phi = true;
% change parameters due to downsampling
if decimate>1
    decimate = 1; %#ok<NASGU>
    disp('to be done');
    % fluo = y;
    % y = resample(y, 1, decimate);
    % g = g^decimate;
    % thresh = thresh / decimate / decimate;
    % T = length(y);
end
g_converged = false;
%% optimize parameters
tol = 1e-4;
if ~optimize_b   %% don't optimize the baseline b
    %% initialization
    b = 0;
    [solution, spks, active_set] = oasisAR1(y, g, lam);
    %% iteratively update parameters lambda & g
    for miter=1:maxIter
        % update g
        if and(optimize_g, ~g_converged)
            g0 = g;
            [solution, active_set, g, spks] = update_g(y, active_set, lam);
            if abs(g-g0)/g0 < 1e-3  % g is converged
                g_converged = true;
            end
        end
        res = y - solution;
        RSS = res' * res;
        len_active_set = size(active_set, 1);
        if RSS>thresh   % constrained form has been found, stop
            break;
        end
        % update lam
        update_phi();
        if ~flag_phi
            % No real dphi reaches the constraint; stop here instead of
            % adding a complex-valued dphi to lam.
            break;
        end
        lam = lam + dphi;
    end
else
    %% initialization
    b = quantile(y, 0.15);
    [solution, spks, active_set] = oasisAR1(y-b, g, lam);
    update_lam_b();
    %% optimize the baseline b and, optionally, g
    for miter=1:maxIter
        res = y - solution - b;
        RSS = res' * res;
        len_active_set = size(active_set,1);
        if or(abs(RSS-thresh) < tol, sum(solution)<1e-9)
            break;
        end
        %% update b & lambda
        update_phi();
        update_lam_b();
        % update g
        if and(optimize_g, ~g_converged)
            g0 = g;
            [solution, active_set, g, spks] = update_g(y-b, active_set, lam);
            if abs(g-g0)/g0 < 1e-4
                g_converged = true;
            end
        end
    end
end
c = solution;
s = spks;
%% nested functions
    function update_phi() % estimate dphi to match the thresholded RSS
        % Build zeta from the current pools (column 2 divides the kernel,
        % column 3 is the pool start, column 4 its length).
        zeta = zeros(size(solution));
        maxl = max(active_set(:, 4));
        h = g.^(0:maxl);
        for ii=1:len_active_set
            ti = active_set(ii, 3);
            li = active_set(ii, 4);
            idx = 0:(li-1);
            if ii<len_active_set
                zeta(ti+idx) = (1-g^li)/ active_set(ii,2) * h(1:li);
            else
                zeta(ti+idx) = 1/active_set(ii,2) * h(1:li);
            end
        end
        % Model the post-shift residual as res + dphi*zeta and solve
        % |res + dphi*zeta|^2 = thresh, i.e. aa*dphi^2 + 2*bb*dphi + cc = 0,
        % keeping the larger root.
        if optimize_b
            % with a free baseline, zeta and res are mean-centered first
            zeta = zeta - mean(zeta);
            tmp_res = res - mean(res);
            cc = tmp_res'*tmp_res - thresh;
        else
            tmp_res = res;
            cc = RSS - thresh;
        end
        aa = zeta'*zeta;
        bb = tmp_res'*zeta;
        disc = bb^2 - aa*cc;
        if disc < 0 && sqrt(-disc)/aa > 1e-9
            % no (numerically) real root: leave the pools unchanged
            flag_phi = false;
            return;
        end
        flag_phi = true;
        % clamp tiny negative discriminants (round-off) so dphi stays real
        % and no complex values leak into active_set
        dphi = (-bb + sqrt(max(disc, 0))) / aa;
        active_set(:,1) = active_set(:,1) - dphi*(1-g.^active_set(:,4));
        [solution, spks, active_set] = oasisAR1([], g, lam, [], active_set);
    end
    function update_lam_b() % estimate lambda & b
        % re-center the baseline on the mean residual, compensating in lam
        db = mean(y-solution) - b;
        b = b + db;
        dlam = -db/(1-g);
        lam = lam + dlam;
        % correct the last pool
        active_set(end,1) = active_set(end,1) - lam*g^(active_set(end,4));
        ti = active_set(end,3); li = active_set(end,4); idx = 0:(li-1);
        solution(ti+idx) = max(0, active_set(end,1)/active_set(end,2)) * (g.^idx);
    end
end
%update the AR coefficient: g
function [c, active_set, g, s] = update_g(y, active_set, lam)
%% Re-estimate the AR coefficient g for fixed pools, then re-run OASIS warm-started
% g is chosen in (0, 1) by fminbnd to minimize the residual sum of squares
% of the pool-wise non-negative fit; the pool statistics (columns 1-2 of
% active_set) are then rebuilt for that g and passed to oasisAR1.
%% inputs:
% y: T*1 vector, one-dimensional array containing the fluorescence
%   intensities with one entry per time bin (reshaped to a column).
% active_set: npools*4 matrix, previous active sets (column 3 = first time
%   index of a pool, column 4 = its length).
% lam: scalar, current value of the sparsity penalty parameter lambda.
%% outputs
% c: T*1 vector, denoised trace returned by oasisAR1
% active_set: npool x 4 matrix, active sets
% g: scalar, the estimated AR(1) coefficient
% s: T*1 vector, spike train
%% Authors: Pengcheng Zhou, Carnegie Mellon University, 2016
% ported from the Python implementation from Johannes Friedrich
%% References
% Friedrich J et.al., NIPS 2016, Fast Active Set Method for Online Spike Inference from Calcium Imaging
%% initialization
len_active_set = size(active_set, 1); % number of pools
y = reshape(y, [], 1);                 % fluorescence data as a column
maxl = max(active_set(:, 4));          % longest pool
c = zeros(size(y));                    % shared with rss_g, which fills it in
%% find the optimal g and get the warm started active_set
g = fminbnd(@rss_g, 0, 1);
% Rebuild the kernel for the returned g here; rss_g's kernel was computed
% for fminbnd's trial values, which need not equal the returned optimum.
tmp_h = g.^(0:maxl)';            % response kernel (same for every pool)
tmp_hh = cumsum(tmp_h.*tmp_h);   % tmp_hh(k) = tmp_h(1:k)'*tmp_h(1:k)
yp = y - lam*(1-g);              % include the penalty term
for m = 1:len_active_set
    li = active_set(m, 4);
    ti = active_set(m, 3);
    idx = ti:(ti+li-1);
    active_set(m,1) = (yp(idx))'*tmp_h(1:li);
    active_set(m,2) = tmp_hh(li);
end
[c, s, active_set] = oasisAR1(y, g, lam, [], active_set);
%% nested functions
    function rss = rss_g(g_trial)
        % RSS of the pool-wise non-negative fit for a trial value of g
        h = g_trial.^(0:maxl)';      % response kernel
        hh = cumsum(h.*h);           % hh(k) = h(1:k)'*h(1:k)
        yp = y - lam*(1-g_trial);    % include the penalty term
        for ii = 1:len_active_set
            li = active_set(ii, 4);
            ti = active_set(ii, 3);
            idx = ti:(ti+li-1);
            tmp_v = max(yp(idx)' * h(1:li) / hh(li), 0);
            c(idx) = tmp_v*h(1:li);
        end
        res = y - c;
        rss = res'*res;   % residual sum of squares
    end
end