@brlauuu
Created January 25, 2018 13:48
Training CNN using Matlab R2017b NN and CV toolbox
% Load training data.
imageDir = fullfile('training_data');
labelDir = fullfile('label_data');
% Create an image datastore for the images.
imds = imageDatastore(imageDir, 'IncludeSubfolders',true, ...
'LabelSource','foldernames');
% Create a pixelLabelDatastore for the ground truth pixel labels.
classNames = ["good","bad"];
labelIDs = [255 0];
pxds = pixelLabelDatastore(labelDir, classNames, labelIDs);
%%%%%% Visualize training images and ground truth pixel labels.
I = read(imds);
C = read(pxds);
figure
I = imresize(I,5);
L = imresize(uint8(C),5);
imshowpair(I,L,'montage')
% Create a simple semantic segmentation network based on a downsampling
% (convolution and max pooling) and upsampling (transposed convolution)
% design.
numFilters = 40;
filterSize = 15;
% One output channel per pixel class defined for the pixelLabelDatastore.
numClasses = numel(classNames);
layers = [
    imageInputLayer([64 64 1])
    convolution2dLayer(filterSize,numFilters,'Padding',0)
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(filterSize,numFilters/2,'Padding',0)
    maxPooling2dLayer(2,'Stride',2)
    reluLayer()
    transposedConv2dLayer(filterSize,numFilters,'Stride',2)
    convolution2dLayer(1,numClasses)
    softmaxLayer()
    pixelClassificationLayer()
    ]
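% The spatial size produced by the layers above can be traced by hand. This
% is a rough sketch, assuming stride-1 convolutions without padding, 2x2
% max pooling with stride 2 (fractional sizes rounded down), and a
% transposed convolution without cropping. The names traceFilter and
% traceSize are illustrative and not part of the original script.
traceFilter = 15;                                % same value as filterSize above
traceSize = 64;                                  % input height and width
traceSize = traceSize - traceFilter + 1;         % first convolution: 50
traceSize = floor((traceSize - 2) / 2) + 1;      % first max pooling: 25
traceSize = traceSize - traceFilter + 1;         % second convolution: 11
traceSize = floor((traceSize - 2) / 2) + 1;      % second max pooling: 5
traceSize = (traceSize - 1) * 2 + traceFilter;   % transposed convolution: 23
fprintf('Estimated output size: %d x %d\n', traceSize, traceSize);
% Under these assumptions the estimate is 23 x 23, so it is worth checking
% that the network output matches the size of the label images.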
% Set up training options.
opts = trainingOptions('sgdm', ...
'InitialLearnRate', 1e-3, ...
'MaxEpochs', 50, ...
'MiniBatchSize', 100);
% Create a data source for training data.
trainingData = pixelLabelImageSource(imds,pxds);
% Train the network.
net = trainNetwork(trainingData,layers,opts);
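% Keep the trained network under a separate name and save the workspace
% (including mem_net) to mem_net.mat.
mem_net = net;
save mem_net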
% Read and display a test image.
testImage = imread('fil028.jpg');
figure
imshow(testImage)
% Segment the test image and display the results.
%C = semanticseg(testImage,net);
%B = labeloverlay(testImage,C);
%figure
%imshow(B)
%%%%%% Improve the results if necessary
% The network failed to properly segment the triangles and classified every
% pixel as "background". The training appeared to be going well with
% training accuracies greater than 90%. However, the network only learned
% to classify the background class. To understand why this happened, you
% can count the occurrence of each pixel label across the dataset.
% The majority of pixel labels are for the background. The poor results are
% due to the class imbalance. Class imbalance biases the learning process
% in favor of the dominant class. That's why every pixel is classified as
% "background". To fix this, use class weighting to balance the classes.
% There are several methods for computing class weights. One common method
% is inverse frequency weighting where the class weights are the inverse of
% the class frequencies. This increases weight given to under-represented
% classes.
%tbl = countEachLabel(trainingData)
%totalNumberOfPixels = sum(tbl.PixelCount);
%frequency = tbl.PixelCount / totalNumberOfPixels;
%classWeights = 1./frequency
%layers(end) = pixelClassificationLayer('ClassNames',tbl.Name,'ClassWeights',classWeights);
%net = trainNetwork(trainingData,layers,opts);
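% A small worked example of inverse frequency weighting, as a sketch with
% made-up pixel counts (hypothetical values, not measured on this dataset).
% The example* variable names are illustrative only.
examplePixelCount = [9000; 1000];                               % hypothetical counts for two classes
exampleFrequency = examplePixelCount / sum(examplePixelCount);  % [0.9; 0.1]
exampleWeights = 1 ./ exampleFrequency                          % about 1.1111 and 10
% The rarer class (10% of the pixels) receives a weight nine times larger
% than the dominant class.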
% Try to segment the test image again.
C = semanticseg(testImage,net);
B = labeloverlay(testImage,C);
figure
imshow(B)
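% Optional check (a sketch, not part of the original script): count how
% many pixels were assigned to each class. It assumes C is the categorical
% label matrix returned by semanticseg; predictedCounts is an illustrative
% name. A count of zero for a class suggests that the network collapsed
% onto the dominant class.
predictedCounts = table(categories(C), countcats(C(:)), ...
    'VariableNames', {'Name','PixelCount'})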