-
Notifications
You must be signed in to change notification settings - Fork 1
/
runTask.m
53 lines (50 loc) · 2.22 KB
/
runTask.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
function runTask(varargin)
%RUNTASK Cross-validate one classifier per feature extractor and save results.
%   RUNTASK('Name', Value, ...) constructs one classifier per entry of
%   'featureExtractors' using the 'classifier' constructor. It then runs
%   CROSSVAL over 'kfoldValues' with an evaluation function obtained by
%   currying EVALUATE with the classifiers, 'getRows' and 'getLabels'.
%   Finally it saves the variable 'results' to
%   <dataPath>/results/<resultsFilename>.
%
%   Name-value parameters:
%     'dataPath'          - existing folder whose 'results' subfolder
%                           receives the output file (default: folder
%                           containing this file). The 'results' subfolder
%                           is created if it does not exist.
%     'kfoldValues'       - non-empty numeric data passed to CROSSVAL as the
%                           values to partition (must be supplied).
%     'kfold'             - number of cross-validation folds (default: 5).
%     'getRows'           - function handle forwarded to EVALUATE.
%     'getLabels'         - function handle forwarded to EVALUATE.
%     'featureExtractors' - non-empty cell array of FeatureExtractor objects.
%     'classifier'        - function handle mapping a FeatureExtractor to a
%                           Classifier (default: @LibsvmClassifierCCV).
%     'resultsFilename'   - output file name, char or string scalar
%                           (default: current time as
%                           'yyyy-mm-dd_HH-MM-SS.mat').
%
%   Unrecognized name-value pairs are ignored (KeepUnmatched is true).
%   The global random number generator is reseeded (rng(1, 'twister'))
%   so that runs are reproducible.

    %% Parameters
    argParser = inputParser();
    % Extra name-value pairs are accepted and silently ignored.
    argParser.KeepUnmatched = true;
    argParser.addParameter('dataPath', fileparts(mfilename('fullpath')), ...
        @(p) exist(p, 'dir'));
    argParser.addParameter('kfoldValues', [], @(x) ~isempty(x) && isnumeric(x));
    argParser.addParameter('kfold', 5, @isnumeric);
    argParser.addParameter('getRows', [], @(f) isa(f, 'function_handle'));
    argParser.addParameter('getLabels', [], @(f) isa(f, 'function_handle'));
    argParser.addParameter('featureExtractors', {}, ...
        @(fs) iscell(fs) && ~isempty(fs) ...
        && all(cellfun(@(f) isa(f, 'FeatureExtractor'), fs)));
    argParser.addParameter('classifier', @LibsvmClassifierCCV, ...
        @(c) isa(c, 'function_handle'));
    argParser.addParameter('resultsFilename', ...
        [datestr(datetime(), 'yyyy-mm-dd_HH-MM-SS'), '.mat'], ...
        @(s) ischar(s) || (isstring(s) && isscalar(s)));
    argParser.parse(varargin{:});

    dataPath = argParser.Results.dataPath;
    kfoldValues = argParser.Results.kfoldValues;
    kfold = argParser.Results.kfold;
    getRows = argParser.Results.getRows;
    getLabels = argParser.Results.getLabels;
    featureExtractors = argParser.Results.featureExtractors;
    classifierConstructor = argParser.Results.classifier;
    resultsFilename = char(argParser.Results.resultsFilename);

    % Defaults bypass the parser's validators, so re-check the required
    % inputs here to fail before any work is done.
    assert(~isempty(featureExtractors), 'featureExtractors must not be empty');
    assert(~isempty(kfoldValues), 'kfoldValues must be provided and non-empty');
    % NOTE(review): getRows/getLabels default to [] and are not checked here;
    % confirm whether EVALUATE tolerates empty handles.

    %% Setup
    % One classifier per feature extractor.
    classifiers = cellfun(@(featureExtractor) ...
        classifierConstructor(featureExtractor), ...
        featureExtractors, 'UniformOutput', false);
    assert(all(cellfun(@(c) isa(c, 'Classifier'), classifiers)), ...
        'classifier must be of type ''Classifier''');

    % Make sure the output folder exists *before* the (potentially long)
    % cross-validation, so a bad path does not waste the whole run.
    resultsDir = fullfile(dataPath, 'results');
    if ~exist(resultsDir, 'dir')
        [created, mkdirMessage] = mkdir(resultsDir);
        if ~created
            error('runTask:resultsDir', ...
                'Could not create results folder ''%s'': %s', ...
                resultsDir, mkdirMessage);
        end
    end
    resultsFile = fullfile(resultsDir, resultsFilename);

    % Seed the global generator so that fold partitioning is reproducible.
    rng(1, 'twister');

    %% Run
    evaluateClassifiers = curry(@evaluate, classifiers, getRows, getLabels);
    results = crossval(evaluateClassifiers, kfoldValues, 'kfold', kfold);
    % Parallel variant (currently disabled): start a pool with parpool, then
    % add to the crossval call:
    %   'Options', statset('UseParallel', true, ...
    %       'Streams', RandStream('mlfg6331_64'), 'UseSubstreams', true)
    % and delete the pool afterwards.

    save(resultsFile, 'results');
    fprintf('Results stored in ''%s''\n', resultsFile);
end