-
Notifications
You must be signed in to change notification settings - Fork 0
/
odexample.m
127 lines (107 loc) · 4.11 KB
/
odexample.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
% ODEXAMPLE  Faster R-CNN vehicle detection walkthrough.
%   Loads the vehicle ground-truth table, splits it into training,
%   validation and test partitions, builds a Faster R-CNN layer graph on
%   top of ResNet-50, then either trains a detector or loads a pretrained
%   one, and finally runs the detector on one test image.

%% Load the ground truth table
data = load('vehicleDatasetGroundTruth.mat');
vehicleDataset = data.vehicleDataset;

%% Split into training / validation / test partitions
rng(0)  % fixed seed so the shuffle (and therefore the split) is repeatable
shuffledIndices = randperm(height(vehicleDataset));

% First floor(60%) of the shuffled rows are used for training.
idx = floor(0.6 * height(vehicleDataset));
trainingIdx = 1:idx;
trainingDataTbl = vehicleDataset(shuffledIndices(trainingIdx),:);

% NOTE: this range holds floor(0.1*n)+1 rows (both end points are
% inclusive), i.e. one more than a strict 10% slice. It is kept unchanged so
% the test partition below stays the same.
validationIdx = idx+1 : idx + 1 + floor(0.1 * length(shuffledIndices) );
validationDataTbl = vehicleDataset(shuffledIndices(validationIdx),:);

% Everything after the validation rows is held out for testing.
testIdx = validationIdx(end)+1 : length(shuffledIndices);
testDataTbl = vehicleDataset(shuffledIndices(testIdx),:);

%% Build image + box-label datastores for every partition
imdsTrain = imageDatastore(trainingDataTbl{:,'imageFilename'});
bldsTrain = boxLabelDatastore(trainingDataTbl(:,'vehicle'));
imdsValidation = imageDatastore(validationDataTbl{:,'imageFilename'});
bldsValidation = boxLabelDatastore(validationDataTbl(:,'vehicle'));
imdsTest = imageDatastore(testDataTbl{:,'imageFilename'});
bldsTest = boxLabelDatastore(testDataTbl(:,'vehicle'));

trainingData = combine(imdsTrain,bldsTrain);
validationData = combine(imdsValidation,bldsValidation);
testData = combine(imdsTest,bldsTest);

%% Preview one raw training sample with its box drawn on top
sample = read(trainingData);
I = sample{1};
bbox = sample{2};
annotatedImage = insertShape(I,'Rectangle',bbox);
annotatedImage = imresize(annotatedImage,2);
figure
imshow(annotatedImage)
% Rewind so the preview does not leave the datastore positioned past its
% first sample.
reset(trainingData);

%% Resize images and boxes to the network input size
inputSize = [224 224 3];
% preprocessedTrainingData = transform(trainingData, @(data)preprocessData(data,inputSize));
trainingData = transform(trainingData,@scaleGT);
% Validation samples must go through the same resizing as the training
% samples; otherwise the boxes and images handed to the trainer for
% validation are at a different scale than the data the network is fit on.
validationData = transform(validationData,@scaleGT);

%% Estimate anchor boxes and assemble the network
numAnchors = 3;
% return;
anchorBoxes = estimateAnchorBoxes(trainingData,numAnchors)
featureExtractionNetwork = resnet50;
featureLayer = 'activation_40_relu';
% One class per table column other than the first one
% (presumably the non-class column is imageFilename - confirm against the table).
numClasses = width(vehicleDataset)-1;
lgraph = fasterRCNNLayers(inputSize,numClasses,anchorBoxes,featureExtractionNetwork,featureLayer);

%% Preview one resized training sample
reset(trainingData);
sample = read(trainingData);
I = sample{1};
bbox = sample{2};
annotatedImage = insertShape(I,'Rectangle',bbox);
annotatedImage = imresize(annotatedImage,2);
figure
imshow(annotatedImage)
reset(trainingData);

% Optional augmentation preview (disabled):
% augmentedTrainingData = transform(trainingData,@augmentData);
% augmentedData = cell(4,1);
% for k = 1:4
% data = read(augmentedTrainingData);
% augmentedData{k} = insertShape(data{1},'Rectangle',data{2});
% reset(augmentedTrainingData);
% end
% figure
% montage(augmentedData,'BorderSize',10)

%% Training options
options = trainingOptions('sgdm',...
    'MaxEpochs',10,...
    'MiniBatchSize',2,...
    'InitialLearnRate',1e-3,...
    'CheckpointPath',tempdir,...
    'ValidationData',validationData);

%% Train the detector, or load a pretrained one
doTraining = false;
if doTraining
    % Train the Faster R-CNN detector.
    % * Adjust NegativeOverlapRange and PositiveOverlapRange to ensure
    %   that training samples tightly overlap with ground truth.
    [detector, info] = trainFasterRCNNObjectDetector(trainingData,lgraph,options, ...
        'NegativeOverlapRange',[0 0.3], ...
        'PositiveOverlapRange',[0.6 1]);
else
    % Load pretrained detector for the example.
    pretrained = load('fasterRCNNResNet50EndToEndVehicleExample.mat');
    detector = pretrained.detector;
end

%% Run the detector on one test image
I = imread(testDataTbl.imageFilename{2});
I = imresize(I,inputSize(1:2));  % match the size the network was built for
[bboxes,scores] = detect(detector,I);
I = insertObjectAnnotation(I,'rectangle',bboxes,scores);
figure
imshow(I)

function data = augmentData(data)
% augmentData  Randomly mirror an image and its boxes left-to-right.
%   data is a cell array: data{1} is the image and data{2} the bounding
%   boxes. Both are returned warped by the same random transform.

    img = data{1};
    imgSize = size(img);

    % Random horizontal reflection and the output view that goes with it.
    flipTform = randomAffine2d('XReflection',true);
    outView = affineOutputView(imgSize,flipTform);

    % Warp the image.
    data{1} = imwarp(img,flipTform,'OutputView',outView);

    % Clean up the boxes before warping them. helperSanitizeBoxes is not
    % defined in this file; it has to be available on the MATLAB path.
    cleanBoxes = helperSanitizeBoxes(data{2}, imgSize);

    % Apply the same transform to the boxes.
    data{2} = bboxwarp(cleanBoxes,flipTform,outView);
end
function data = preprocessData(data,targetSize)
% preprocessData  Resize an image and its bounding boxes to targetSize.
%   data is a cell array: data{1} is the image and data{2} the bounding
%   boxes. Only the first two elements of targetSize are used.

    outSize = targetSize(1:2);
    inSize = size(data{1},[1 2]);   % measured before the image is resized

    % Resize the image.
    data{1} = imresize(data{1},outSize);

    % Clean up the boxes against the original image size.
    % helperSanitizeBoxes is not defined in this file; it has to be
    % available on the MATLAB path.
    cleanBoxes = helperSanitizeBoxes(data{2}, inSize);

    % Scale the boxes by the same per-dimension factors as the image.
    data{2} = bboxresize(cleanBoxes,outSize./inSize);
end
function data = scaleGT(data,targetSize)
% scaleGT  Resize an image and its ground-truth boxes to a common size.
%   data = scaleGT(data) resizes the image in data{1} to 224-by-224 and
%   rescales the bounding boxes in data{2} by the same per-dimension
%   factors.
%
%   data = scaleGT(data,targetSize) resizes to targetSize instead. Only the
%   first two elements of targetSize are used, so a network input size
%   such as [224 224 3] can be passed directly.
%
%   The one-argument form keeps the function usable as a plain handle,
%   e.g. transform(ds,@scaleGT).

    if nargin < 2 || isempty(targetSize)
        targetSize = [224 224];
    end
    targetSize = targetSize(1:2);

    % data{1} is the image; the scale must be computed from its size
    % before it is overwritten by the resized image.
    scale = targetSize./size(data{1},[1 2]);
    data{1} = imresize(data{1},targetSize);

    % data{2} is the bounding box
    data{2} = bboxresize(data{2},scale);
end