@ehzawad
Created October 3, 2020 23:04
% Load the images; each subfolder name is used as the class label.
parentDir = '~/Desktop/omg_deep_learning/';
dataDir = 'ehza_datasets_COVID';
allImages = imageDatastore(fullfile(parentDir, dataDir), ...
    'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');

% Use 80% of each class for training and the rest for validation.
[imgsTrain, imgsValidation] = splitEachLabel(allImages, 0.80, 'randomized');
disp(['Number of training images: ', num2str(numel(imgsTrain.Files))]);
disp(['Number of validation images: ', num2str(numel(imgsValidation.Files))]);
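
% Load pretrained GoogLeNet and read its expected input size.
net = googlenet;
layers = net.Layers;
inputSize = net.Layers(1).InputSize;
lgraph = layerGraph(net);

% Replace the final fully connected layer with one sized for our classes.
% The learn rate factors of 10 make it train faster than the transferred layers.
numClasses = numel(categories(imgsTrain.Labels));
newLearnableLayer = fullyConnectedLayer(numClasses, ...
    'Name', 'new_fc', ...
    'WeightLearnRateFactor', 10, ...
    'BiasLearnRateFactor', 10);
lgraph = replaceLayer(lgraph, 'loss3-classifier', newLearnableLayer);

% Replace the classification output layer so it picks up the new class names.
newClassLayer = classificationLayer('Name', 'new_classoutput');
lgraph = replaceLayer(lgraph, 'output', newClassLayer);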

% Augment the training images with random horizontal flips and
% translations of up to 30 pixels; both sets are resized to the network input size.
pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection', true, ...
    'RandXTranslation', pixelRange, ...
    'RandYTranslation', pixelRange);
augimgsTrain = augmentedImageDatastore(inputSize(1:2), imgsTrain, ...
    'DataAugmentation', imageAugmenter);
augimgsValidation = augmentedImageDatastore(inputSize(1:2), imgsValidation);

% Training options. 'parallel' uses a parallel pool and requires
% Parallel Computing Toolbox.
options = trainingOptions('sgdm', ...
    'MiniBatchSize', 90, ...
    'MaxEpochs', 1, ...
    'InitialLearnRate', 1e-4, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', augimgsValidation, ...
    'ValidationFrequency', 3, ...
    'Verbose', true, ...
    'ExecutionEnvironment', 'parallel', ...
    'Plots', 'training-progress');
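
% Fine-tune the network, then evaluate it on the validation set.
netTransfer = trainNetwork(augimgsTrain, lgraph, options);
trueLabels = imgsValidation.Labels;
[YPred, probs] = classify(netTransfer, augimgsValidation);
accuracy = mean(YPred == trueLabels);
disp(['Validation accuracy: ', num2str(accuracy)]);
plotconfusion(trueLabels, YPred);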