-
Notifications
You must be signed in to change notification settings - Fork 2
/
pst_cv_plot.m
84 lines (59 loc) · 2.04 KB
/
pst_cv_plot.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
function pst_cv_plot(CVDATA,varargin)
% PST_CV_PLOT Plots the output of pst_cross_validate
%
%   pst_cv_plot(CVDATA) groups the entries of CVDATA.test_logl and
%   CVDATA.train_logl by the value of CVDATA.p_min and, in a new figure,
%   plots the mean negative log-likelihood of each group with a bar of
%   +/- one standard deviation. Test results are drawn in cyan, train
%   results in blue. The x axis is logarithmic and reversed.
%
%   pst_cv_plot(CVDATA,'sort_param',NAME) groups by CVDATA.(NAME) instead
%   of CVDATA.p_min.
%
%   Inputs:
%     CVDATA   - structure providing the grouping field, test_logl and
%                train_logl. The log-likelihood fields are indexed with a
%                mask built from the grouping field, so they are assumed
%                to have the same number of elements as it -- TODO confirm
%                against pst_cross_validate.
%     varargin - parameter/value pairs. Only 'sort_param' is recognized;
%                other parameter names are ignored.
%
%   Raises an error if varargin does not hold an even number of entries.

sort_param='p_min';

nparams=length(varargin);
if mod(nparams,2)>0
    error('Parameters must be specified as parameter/value pairs');
end

for i=1:2:nparams
    switch lower(varargin{i})
        case 'sort_param'
            sort_param=varargin{i+1};
        otherwise
            % unrecognized parameters are ignored
    end
end

% group the cv splits that share the same value of the chosen parameter
sort_data=CVDATA.(sort_param);
sort_data=sort_data(:);
param_values=unique(sort_data);
nvalues=length(param_values);

% preallocate the per-group statistics as row vectors
blank=zeros(1,nvalues);
test=struct('mean_val',blank,'std_val',blank,'n_val',blank);
train=test;

for i=1:nvalues
    in_group=(sort_data==param_values(i));

    test_nll=-CVDATA.test_logl(in_group);   % convert to negative logl
    train_nll=-CVDATA.train_logl(in_group);

    test.mean_val(i)=mean(test_nll);
    test.std_val(i)=std(test_nll);
    test.n_val(i)=length(test_nll);

    train.mean_val(i)=mean(train_nll);
    train.std_val(i)=std(train_nll);
    train.n_val(i)=length(train_nll);
end

% row 1: mean minus one standard deviation, row 2: mean plus one
test.ci=[test.mean_val-test.std_val; test.mean_val+test.std_val];
train.ci=[train.mean_val-train.std_val; train.mean_val+train.std_val];

figure();

% draw test first, then train, with identical styling apart from color
series={test,train};
colors={'c','b'};

for s=1:length(series)
    current=series{s};

    % the explicit 'color' overrides the black of the '--ko' line spec
    plot(param_values,current.mean_val,'--ko','color',colors{s},...
        'linewidth',3,'markersize',5)
    box off
    hold on;

    % one vertical segment per parameter value spanning mean -/+ std
    for i=1:nvalues
        line([ param_values(i) param_values(i) ],...
            [ current.ci(1,i) current.ci(2,i) ],...
            'color',colors{s},'linewidth',3);
    end
end

set(gca,'xdir','reverse','xscale','log');
set(gca,'FontSize',15,'FontName','Helvetica');
ylabel('CV Neg. Log Likelihood');

if strcmp(sort_param,'p_min')
    xlabel('Pmin');
else
    % label the axis with the field actually used for grouping; disable
    % the interpreter so underscores are shown literally
    xlabel(sort_param,'interpreter','none');
end