-
Notifications
You must be signed in to change notification settings - Fork 1
/
trigger_threshold_cost_continuous.m
94 lines (76 loc) · 3.65 KB
/
trigger_threshold_cost_continuous.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
% Copyright (C) 2017 Ben Pearre
%
% This file is part of the Zebra Finch Syllable Detector, syllable-detector-learn.
%
% The Zebra Finch Syllable Detector is free software: you can redistribute it and/or
% modify it under the terms of the GNU Lesser General Public License as published by
% the Free Software Foundation, either version 3 of the License, or (at your option)
% any later version.
%
% The Zebra Finch Syllable Detector is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
% FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
% more details.
%
% You should have received a copy of the GNU Lesser General Public License
% along with the Zebra Finch Syllable Detector. If not, see
% <http://www.gnu.org/licenses/>.
function [cost truepositiverate falsepositiverate ] = trigger_threshold_cost_continuous(threshold, ...
    responses, ...
    tstep_of_interest_shifted, ...
    positive_interval, ...
    FALSE_POSITIVE_COST, ...
    songs_with_hits)
% TRIGGER_THRESHOLD_COST_CONTINUOUS  Score a detection threshold over a set of songs.
%
%   [cost, truepositiverate, falsepositiverate] = trigger_threshold_cost_continuous( ...
%       threshold, responses, tstep_of_interest_shifted, positive_interval, ...
%       FALSE_POSITIVE_COST, songs_with_hits)
%
%   Inputs
%     threshold                 - a timestep is "on" when its response > threshold.
%     responses                 - [[ 1 x ] song x timestep ] array of song responses
%                                 for the relevant output neuron.
%     tstep_of_interest_shifted - timestep index the first trigger is compared against
%                                 for the delay cost (assumed scalar -- TODO confirm).
%     positive_interval         - timestep indices, around the aligned data, that
%                                 count as a positive.
%     FALSE_POSITIVE_COST       - weight of each false positive; defaults to 1 when
%                                 omitted or passed as [].
%     songs_with_hits           - one entry per song; true (nonzero) for songs that
%                                 should fire inside positive_interval.
%
%   Outputs
%     cost              - FALSE_POSITIVE_COST * false_positives + false_negatives
%                         + delaycost.
%     truepositiverate  - fraction of songs_with_hits that fire inside
%                         positive_interval (NaN when no song has hits).
%     falsepositiverate - false positives divided by the number of timesteps that
%                         could have produced one (all timesteps minus the masked
%                         target windows of songs_with_hits).

if nargin < 5 || isempty(FALSE_POSITIVE_COST)
    FALSE_POSITIVE_COST = 1;
end

% Collapse to a song x timestep matrix. Unlike squeeze(), this keeps a
% single-song 1 x 1 x T input as 1 x T rather than turning it into T x 1.
sz = size(responses);
responsest = reshape(responses, [], sz(end));
nresponses = size(responsest, 1);

% Force a logical column mask so combining it with per-song column vectors
% below never implicitly expands into an nresponses x nresponses matrix.
songs_with_hits = logical(songs_with_hits(:));

% Delay cost: for every song that should fire and does fire, add a cost
% proportional to how far the first firing is from the timestep of interest.
delaycost = 0;
positive_interval_length = ceil(length(positive_interval)/2); % ceil so that the max delay still costs < 1
for i = 1:nresponses
    if ~songs_with_hits(i)
        continue;
    end
    % First timestep at which trigger() is true for song i; [] if never.
    % trigger() is defined elsewhere in the project -- presumably a per-timestep
    % firing mask; verify against its definition.
    response_on_i = find(trigger(responsest(i, :), threshold), 1);
    if isempty(response_on_i)
        continue;
    end
    diffval = abs(response_on_i - tstep_of_interest_shifted) / positive_interval_length;
    % Exactly on time (0) and delays of a full half-window or more (>= 1)
    % add nothing here.
    if diffval < 1 & diffval > 0
        delaycost = delaycost + diffval;
        %disp(sprintf('Adding delaycost %f', diffval));
    end
end

% Binarize: a timestep is "on" when its response exceeds the threshold.
on = responsest > threshold;

% Per song: is any timestep inside the target window on?
hit_in_window = any(on(:, positive_interval), 2);
% TRUE POSITIVES: songs in songs_with_hits that fire inside the window.
true_positives = sum(hit_in_window & songs_with_hits);
% FALSE NEGATIVES: songs in songs_with_hits with no firing inside the window.
false_negatives = sum(~hit_in_window & songs_with_hits);

% Count false positives at most once per song (true) or once per frame (false).
false_positives_per_song = false;

% Clear the target window of songs that should have hits; every remaining
% "on" timestep is a false positive. Songs not in songs_with_hits keep their
% window, so firing there also counts as a false positive.
masked = false(size(on));
masked(songs_with_hits, positive_interval) = true;
on(masked) = false;

if false_positives_per_song
    % Number of songs with any firing outside the credited windows.
    false_positives = sum(any(on, 2));
    fp_opportunities = nresponses;
else
    false_positives = sum(on(:));
    % Only the masked timesteps are excluded from the denominator, matching
    % exactly the timesteps that can no longer contribute a false positive.
    fp_opportunities = numel(on) - nnz(masked);
end

cost = FALSE_POSITIVE_COST * false_positives + false_negatives + delaycost;
truepositiverate = true_positives / sum(songs_with_hits);
falsepositiverate = false_positives / fp_opportunities;