(Matlab) Multi Object Detector Using YOLOv2, Part 2 - Training
With the training data labeled, the next step is to train the network. This post builds and trains a multi-object detector using YOLOv2 on top of ResNet-50.
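% Load the labeled ground truth (assumes the MAT-file stores it under the variable name gTruth).
data = load('groundTruth.mat');
gTruth = data.gTruth;
% Build one table holding the label columns plus the image file paths.
trainingdata = gTruth.LabelData;
trainingdata.imageFilename = gTruth.DataSource.Source;
% Keep imageFilename (column 3) first, followed by the Pedestrian boxes (column 2).
trainingdata = trainingdata(:,[3,2]);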
% Shuffle the rows with a fixed seed, then split 80% training / 20% validation.
rng(0);
shuffledIdx = randperm(height(trainingdata));
idx = floor(0.8 * length(shuffledIdx) );
trainingIdx = 1:idx;
trainingDataTbl = trainingdata(shuffledIdx(trainingIdx),:);
% The remaining rows (about 20%) form the validation set.
validationIdx = idx+1 : length(shuffledIdx);
validationDataTbl = trainingdata(shuffledIdx(validationIdx),:);
% Training split: image datastore plus box labels for the Pedestrian class.
imdsTrain = imageDatastore(trainingDataTbl{:,'imageFilename'});
%bldsTrain = boxLabelDatastore(trainingDataTbl(:,'Car'));
bldsTrain = boxLabelDatastore(trainingDataTbl(:,'Pedestrian'));
% Validation split, built the same way.
imdsValidation = imageDatastore(validationDataTbl{:,'imageFilename'});
%bldsValidation = boxLabelDatastore(validationDataTbl(:,'Car'));
bldsValidation = boxLabelDatastore(validationDataTbl(:,'Pedestrian'));
% Pair images with their boxes so that each read returns both.
trainingData = combine(imdsTrain,bldsTrain);
validationData = combine(imdsValidation,bldsValidation);
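% Network input size and number of classes (every column except imageFilename).
inputSize = [227 227 3];
numClasses = width(trainingdata)-1;
% Estimate anchor boxes from all labeled boxes; a single anchor is used here.
trainingDataForEstimation = boxLabelDatastore(trainingdata(:,2:end));
numAnchors = 1;
[anchorBoxes, meanIoU] = estimateAnchorBoxes(trainingDataForEstimation, numAnchors)
% Build the YOLOv2 layer graph on a pretrained ResNet-50 feature extractor.
featureExtractionNetwork = resnet50;
featureLayer = 'activation_40_relu';
lgraph = yolov2Layers(inputSize,numClasses,anchorBoxes,featureExtractionNetwork,featureLayer);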
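% SGDM training options; checkpoints are written to the system temp folder.
options = trainingOptions('sgdm', ...
'MiniBatchSize',16, ...
'InitialLearnRate',1e-3, ...
'MaxEpochs',20, ...
'CheckpointPath',tempdir);
% Note: this call trains on the full table (trainingdata), not on the 80% split
% datastore (trainingData), so the validation images are also seen during training.
[detector,info] = trainYOLOv2ObjectDetector(trainingdata,lgraph,options);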
Training finished after about one hour.
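As a quick sanity check of the 80/20 split used in the script, the index arithmetic can be tried on a synthetic table. This is only a minimal sketch: the table `T` and its 10 rows are made-up stand-ins for `trainingdata`, not the actual dataset, and the variable names `trainIdx` and `valIdx` are chosen here for illustration.

```matlab
% Hypothetical stand-in for trainingdata: 10 rows, one column.
T = table((1:10)', 'VariableNames', {'imageFilename'});

rng(0);                                  % fixed seed for a repeatable shuffle
shuffledIdx = randperm(height(T));
idx = floor(0.8 * numel(shuffledIdx));   % size of the training part

trainIdx = shuffledIdx(1:idx);
valIdx   = shuffledIdx(idx+1:end);       % remaining rows, never out of range

% The two parts must be disjoint and together cover every row.
assert(isempty(intersect(trainIdx, valIdx)));
assert(isequal(sort([trainIdx, valIdx]), 1:height(T)));
fprintf('%d training rows, %d validation rows\n', numel(trainIdx), numel(valIdx));
```

Taking the validation rows as "everything after `idx`" keeps the two sets complementary even when the row count is not a multiple of 5.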
I saved the code below separately as yolov2_detect.m and ran it from there.
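% Run the trained detector on a single test image.
img = imread('000008.png');
[bboxes,scores,labels] = detect(detector,img);
% Draw the boxes and class names only when something was detected.
if(~isempty(bboxes))
    img = insertObjectAnnotation(img,'rectangle',bboxes,cellstr(labels));
end
figure
imshow(img)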
Test on an arbitrary image.
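`detect` also returns a confidence score for each box, which the detection script does not use. A minimal sketch of keeping only the confident detections and putting the score into each caption might look like this; the boxes, scores, labels, and the 0.5 cutoff (`minScore`) are made-up values standing in for real `detect` output.

```matlab
% Hypothetical detections in [x y width height] format, standing in for detect() output.
bboxes = [ 50  60  40  90;
          200  80  35  85;
          120 150  20  30];
scores = [0.91; 0.42; 0.67];
labels = categorical({'Pedestrian'; 'Pedestrian'; 'Pedestrian'});

minScore = 0.5;                 % assumed cutoff, tune per dataset
keep = scores >= minScore;      % logical mask of confident detections

bboxes = bboxes(keep, :);
scores = scores(keep);
labels = labels(keep);

% Build one caption per remaining box, combining class name and score.
captions = arrayfun(@(k) sprintf('%s: %.2f', char(labels(k)), scores(k)), ...
    (1:numel(scores))', 'UniformOutput', false);
disp(captions)
```

The resulting `captions` cell array could then be passed to `insertObjectAnnotation` in place of `cellstr(labels)`.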
I also trained on pedestrians, but the dataset contained too few pedestrians to begin with, so the detector did not learn them properly.
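One way to spot this kind of class imbalance before training is to count the labeled boxes per class. The sketch below uses a small made-up label table with `Car` and `Pedestrian` columns, assuming the same layout as the label table in the script (one cell of M-by-4 boxes per image); the box values and the name `labelData` are illustrative only.

```matlab
% Hypothetical label table: each cell holds an M-by-4 matrix of [x y w h] boxes.
labelData = table( ...
    {[10 20 50 30; 80 25 60 35]; [15 40 55 32]; zeros(0,4)}, ...
    {zeros(0,4); [200 50 20 60]; zeros(0,4)}, ...
    'VariableNames', {'Car', 'Pedestrian'});

classNames = labelData.Properties.VariableNames;
counts = zeros(1, numel(classNames));
for c = 1:numel(classNames)
    % The number of boxes in an image is the row count of its cell.
    counts(c) = sum(cellfun(@(b) size(b, 1), labelData.(classNames{c})));
end

% Print the box count for each class.
for c = 1:numel(classNames)
    fprintf('%-12s %d boxes\n', classNames{c}, counts(c));
end
```

A class whose count is far below the others is a likely candidate for poor detection results, as happened with pedestrians here.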