(Matlab) Multi Object Detector Using YOLOv2, Part 2 - Training

Once the training data has been labeled, it is time to train the network. We will build and train a multi-object detector using YOLOv2 with a ResNet-50 backbone.

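% Load the ground truth exported after labeling.
data = load('groundTruth.mat');
gTruth = data.gTruth;

% Build the training table: label columns plus the image file names.
trainingdata = gTruth.LabelData;
trainingdata.imageFilename = gTruth.DataSource.Source;
% Keep the image file names (column 3) and one label column (column 2).
trainingdata = trainingdata(:,[3,2]);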


% Shuffle the rows reproducibly and split them 80/20.
rng(0);
shuffledIdx = randperm(height(trainingdata));
idx = floor(0.8 * length(shuffledIdx) );

trainingIdx = 1:idx;
trainingDataTbl = trainingdata(shuffledIdx(trainingIdx),:);

% The validation set is every row after the training rows.
validationIdx = idx+1 : length(shuffledIdx);
validationDataTbl = trainingdata(shuffledIdx(validationIdx),:);


imdsTrain = imageDatastore(trainingDataTbl{:,'imageFilename'});
%bldsTrain = boxLabelDatastore(trainingDataTbl(:,'Car'));
bldsTrain = boxLabelDatastore(trainingDataTbl(:,'Pedestrian'));

imdsValidation = imageDatastore(validationDataTbl{:,'imageFilename'});
%bldsValidation = boxLabelDatastore(validationDataTbl(:,'Car'));
bldsValidation = boxLabelDatastore(validationDataTbl(:,'Pedestrian'));

trainingData = combine(imdsTrain,bldsTrain);
validationData = combine(imdsValidation,bldsValidation);

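% Network input size and number of classes (every column except imageFilename).
inputSize = [227 227 3];
numClasses = width(trainingdata)-1;


% Estimate anchor boxes from the labeled boxes in the training table.
trainingDataForEstimation = boxLabelDatastore(trainingdata(:,2:end))
numAnchors = 1;
[anchorBoxes, meanIoU] = estimateAnchorBoxes(trainingDataForEstimation, numAnchors)

% Build the YOLOv2 network on top of a pretrained ResNet-50 feature extractor.
featureExtractionNetwork = resnet50;
featureLayer = 'activation_40_relu';
lgraph = yolov2Layers(inputSize,numClasses,anchorBoxes,featureExtractionNetwork,featureLayer);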


options = trainingOptions('sgdm', ...
        'MiniBatchSize',16, ...
        'InitialLearnRate',1e-3, ...
        'MaxEpochs',20, ...
        'CheckpointPath',tempdir);

% Train on the combined training datastore built from the 80% split.
[detector,info] = trainYOLOv2ObjectDetector(trainingData,lgraph,options);
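An 80/20 index split like the one used in this script can be checked in isolation before spending an hour on training. The following is only a minimal sketch: `numRows` is a hypothetical row count standing in for `height(trainingdata)`, and the validation indices are assumed to run from `idx+1` to the last row so that they never go out of range.

```matlab
% Hypothetical row count; the real script would use height(trainingdata).
numRows = 103;

rng(0);                               % fixed seed for a reproducible shuffle
shuffledIdx = randperm(numRows);
idx = floor(0.8 * numRows);           % position of the last training row

trainingIdx   = 1:idx;
validationIdx = idx+1:numRows;        % stops at the last row, never past it

trainRows = shuffledIdx(trainingIdx);
valRows   = shuffledIdx(validationIdx);

% The two sets must not overlap and together must cover every row once.
assert(isempty(intersect(trainRows, valRows)));
assert(isequal(sort([trainRows valRows]), 1:numRows));
fprintf('%d training rows, %d validation rows\n', numel(trainRows), numel(valRows));
```

If either assertion fails on the real table size, the split indices should be fixed before the datastores are built.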

Training finished after about one hour.

I saved the code below separately as yolov2_detect.m and ran it from there.

img = imread('000008.png');
[bboxes,scores,labels] = detect(detector,img);

if(~isempty(bboxes))
    img = insertObjectAnnotation(img,'rectangle',bboxes,cellstr(labels));
end
figure
imshow(img)
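The annotation step above shows only the class name on each box. A possible variation is to drop low-scoring boxes and print the score next to the label. The sketch below is illustrative only: the `bboxes`, `scores`, and `labels` values are hypothetical placeholders standing in for the outputs of `detect`, and `minScore` and `captions` are names introduced here, not part of the original script.

```matlab
% Hypothetical placeholder values standing in for the outputs of
% detect(detector,img); they are not real detection results.
bboxes = [50 60 120 80; 200 90 40 100; 10 10 30 30];   % [x y width height]
scores = [0.91; 0.62; 0.35];
labels = categorical({'Car'; 'Pedestrian'; 'Car'});

% Keep only detections at or above an assumed score threshold.
minScore = 0.5;
keep = scores >= minScore;
bboxes = bboxes(keep,:);
scores = scores(keep);
labels = labels(keep);

% Build one caption per remaining box, combining class name and score.
captions = cell(numel(scores),1);
for k = 1:numel(scores)
    captions{k} = sprintf('%s: %.2f', char(labels(k)), scores(k));
end

% With real detections, the captions can be passed in place of cellstr(labels):
% img = insertObjectAnnotation(img,'rectangle',bboxes,captions);
disp(captions)
```

The `detect` function also accepts a `'Threshold'` name-value argument, so the filtering could be done at detection time instead of afterwards.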

Test on an arbitrary image.

I also trained on the Pedestrian class, but the dataset contained too few pedestrians to begin with, so the detector did not learn that class properly.
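A class imbalance like this can be spotted before training by counting the labeled boxes per class. The following is a rough sketch, assuming a label table laid out like `gTruth.LabelData` (one column per class, each cell holding an M-by-4 matrix of boxes). The tiny `labelData` table and its box values are hypothetical and only serve to make the example runnable.

```matlab
% Hypothetical label table in the same layout as gTruth.LabelData:
% one column per class, one row per image, each cell an M-by-4 box matrix.
labelData = table( ...
    {[10 10 50 30; 80 20 40 25]; [5 5 60 40]; zeros(0,4)}, ...
    {zeros(0,4); [30 40 10 25]; zeros(0,4)}, ...
    'VariableNames', {'Car','Pedestrian'});

classNames = labelData.Properties.VariableNames;
counts = zeros(1, numel(classNames));
for k = 1:numel(classNames)
    boxes = labelData.(classNames{k});                 % cell array, one entry per image
    counts(k) = sum(cellfun(@(b) size(b,1), boxes));   % total boxes for this class
    fprintf('%s: %d boxes\n', classNames{k}, counts(k));
end
```

On a real dataset, `countEachLabel` applied to a `boxLabelDatastore` should give a similar per-class summary. If one class has far fewer boxes than the others, it probably needs more labeled examples before training.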