Create Faster R-CNN DN
Browse files- Faster R-CNN DN +113 -0
Faster R-CNN DN
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
inputSize = [224 224 3];
|
| 2 |
+
|
| 3 |
+
preprossedTrainingData = transform(trainingData, @(data)preprocessData(data,inputSize));
|
| 4 |
+
numAnchors = 3;
|
| 5 |
+
anchorBoxes = estimateAnchorBoxes(preprocessedTrainingData,numAnchors)
|
| 6 |
+
|
| 7 |
+
featuresExtractionNetwork = resnet50;
|
| 8 |
+
|
| 9 |
+
featureLayer - "activation_40_relu";
|
| 10 |
+
|
| 11 |
+
numClasses = width(vehicleDataset)-1;
|
| 12 |
+
|
| 13 |
+
lgraph = fasterRCNNLayers(inputSize,numClasses,anchorBoxes,featureExtractionNetwork,featureLayer);
|
| 14 |
+
|
| 15 |
+
augmentedTrainingData = transform(trainingData,@aumentData);
|
| 16 |
+
|
| 17 |
+
augmentedData = cell(4,1);
|
| 18 |
+
for k = 1:4
|
| 19 |
+
data = read(augmentedTrainingData);
|
| 20 |
+
augmentedData{k} = insertShape)data{1},"rectangle",data{2});
|
| 21 |
+
reset(augmentedTrainingData);
|
| 22 |
+
end
|
| 23 |
+
figure
|
| 24 |
+
montage(augmentedData,BorderSize=10)
|
| 25 |
+
|
| 26 |
+
trainingData = transform(augmentedTrainingData,@(data)preprocessData(data,inputSize));
|
| 27 |
+
validationData = transform(validationData,@(data)preprocessData(data,inputSize));
|
| 28 |
+
|
| 29 |
+
data = read(trainingData);
|
| 30 |
+
|
| 31 |
+
I = data{1};
|
| 32 |
+
bbox = data{2};
|
| 33 |
+
annotatedImage = insertShape(I,"rectangle",bbox);
|
| 34 |
+
annotatedImage = imresize(annotatedImage,2);
|
| 35 |
+
figure
|
| 36 |
+
imshow(annotatedImage)
|
| 37 |
+
|
| 38 |
+
// Train Faster R-CNN
|
| 39 |
+
|
| 40 |
+
options = trainingOptions("sgdm",...
|
| 41 |
+
MaxEpochs=10,...
|
| 42 |
+
MiniBatchSize=2,...
|
| 43 |
+
InitialLearnRate=1e-3,...
|
| 44 |
+
CheckpointPatin=tempdir,...
|
| 45 |
+
ValidationData=validationData);
|
| 46 |
+
|
| 47 |
+
if doTraining
|
| 48 |
+
% Train the Faster R-CNN detector.
|
| 49 |
+
% * Adjust NegativeOveralpRange and PositiveOverlapRange to ensure
|
| 50 |
+
% that training samples tightly overlap with ground truth.
|
| 51 |
+
[detector, info] = trainFasterRCNNObjectDetector(training
|
| 52 |
+
NegativeOverlapRange=[0 0.3], ...
|
| 53 |
+
PositiveOverlapRange=[0.6 1]);
|
| 54 |
+
else
|
| 55 |
+
% Load pretrained detector for the example.
|
| 56 |
+
pretrained = load("fasterRCNNResNet50EndToEndVehicleExample.mat");
|
| 57 |
+
detector = pretrained.detetor;
|
| 58 |
+
end
|
| 59 |
+
|
| 60 |
+
I = imread(testDataTbl.imageFilename{3});
|
| 61 |
+
I = imresize(I,inputSize(1:2));
|
| 62 |
+
[bboxes,scores] = detect(detector,I);
|
| 63 |
+
|
| 64 |
+
I = insertObjectAnnotation(I,"rectangle",bboxes,scores);
|
| 65 |
+
figure
|
| 66 |
+
imshow(I)
|
| 67 |
+
|
| 68 |
+
testData = transform(testData,@(data)preprocessData(data,inputSize));
|
| 69 |
+
|
| 70 |
+
detectionResults = detect(detector,testData,...
|
| 71 |
+
Threshold=0.2,...
|
| 72 |
+
MiniBatchSize=4);
|
| 73 |
+
|
| 74 |
+
classID = 1;
|
| 75 |
+
metrics = evaluateObjectDetection(detectionResults,testData);
|
| 76 |
+
precision = metrics.ClassMetrics.Precision{classID};
|
| 77 |
+
recall = metrics.ClassMetrics.Recall{classID};
|
| 78 |
+
|
| 79 |
+
figure
|
| 80 |
+
plot(recall,precision)
|
| 81 |
+
xlabel("Recall")
|
| 82 |
+
ylable("Precision")
|
| 83 |
+
grid on
|
| 84 |
+
title(sprintf("Average Precision = %.2f", metrics.ClassMetrics.mAP(classID)))
|
| 85 |
+
|
| 86 |
+
function data = augmentData(data)
|
| 87 |
+
% Randomly flip images and bounding boxes horizontally.
|
| 88 |
+
tform = randomAffine2d("XReflection",true);
|
| 89 |
+
sz = size(data{1});
|
| 90 |
+
rout = affineOutputView(sz,tform);
|
| 91 |
+
data{1} = imwarp(data{1},tform,"OutputView",rout);
|
| 92 |
+
|
| 93 |
+
% Sanitize boxes, if needed. This helper function is attached as a
|
| 94 |
+
% supporting file. Open the example in MATLAB to open this function.
|
| 95 |
+
data{2} = helperSanitizeBoxes(data{2});
|
| 96 |
+
|
| 97 |
+
% Warp boxes.
|
| 98 |
+
data{2} = bboxwwarp(data{2},tform,rout);
|
| 99 |
+
end
|
| 100 |
+
|
| 101 |
+
function data = preprocessData(data,targetSize)
|
| 102 |
+
% Resize image and bounding boxes to targetSize.
|
| 103 |
+
sz = size(data{1},[1 2]);
|
| 104 |
+
scale = targetSize(1:2)./sz;
|
| 105 |
+
data{1} = imresize(data{1},targetSize(1:2));
|
| 106 |
+
|
| 107 |
+
% Sanitize boxes, if needed. This helper function is attached as a
|
| 108 |
+
% supporting file. Open the example in MATLAB to open this function.
|
| 109 |
+
data{2} = helperSanitizeBoxes(data{2});
|
| 110 |
+
|
| 111 |
+
% Resize boxes.
|
| 112 |
+
data{2} = bboxresize(data{2},scale);
|
| 113 |
+
end
|