-
-
Save wocks1123/90e711f11067982e8c57e4a501329ac6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
 * Deep Learning with TensorFlow 2 and Keras, 2nd ed. — Chapter 13
 * (Korean edition: 텐서플로2와 케라스로 구현하는 딥러닝 2/e).
 */
/**
 * Builds and compiles the CNN used to classify MNIST digits.
 *
 * Three conv2d (3x3 kernel, stride 1, ReLU, He-uniform init) + maxPooling2d
 * (2x2, stride 2) stages with 32, 64 and 128 filters, followed by flatten and a
 * softmax dense layer with one unit per class. Compiled with Adam (lr 0.001),
 * categorical cross-entropy loss and an accuracy metric.
 *
 * @returns {tf.Sequential} the compiled, untrained model.
 */
function getModel() {
  const IMAGE_WIDTH = 28;
  const IMAGE_HEIGHT = 28;
  const IMAGE_CHANNELS = 1;
  const NUM_OUTPUT_CLASSES = 10;
  // Filter count of each conv/pool stage, in order.
  const CONV_FILTERS = [32, 64, 128];

  const model = tf.sequential();

  CONV_FILTERS.forEach((filters, stage) => {
    const convConfig = {
      kernelSize: 3,
      filters,
      strides: 1,
      activation: "relu",
      kernelInitializer: "heUniform",
    };
    if (stage === 0) {
      // Only the first layer declares the input shape. With the default
      // channelsLast data format TF.js expects [height, width, channels].
      convConfig.inputShape = [IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS];
    }
    model.add(tf.layers.conv2d(convConfig));
    model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
  });

  model.add(tf.layers.flatten());
  model.add(
    tf.layers.dense({
      units: NUM_OUTPUT_CLASSES,
      kernelInitializer: "heUniform",
      activation: "softmax",
    })
  );

  model.compile({
    optimizer: tf.train.adam(0.001),
    loss: "categoricalCrossentropy",
    metrics: ["accuracy"],
  });
  return model;
}
/**
 * Trains `model` on a slice of the MNIST training split, validating on images
 * from the test split, and plots loss/accuracy live with tfjs-vis.
 *
 * @param {tf.LayersModel} model - compiled model (see getModel()).
 * @param {MnistData} data - loaded MNIST data source.
 * @returns {Promise<tf.History>} whatever model.fit() resolves to.
 */
async function train(model, data) {
  const metrics = ["loss", "val_loss", "acc", "val_acc"];
  const container = {
    name: "model Training",
    styles: {
      height: "10000px",
    },
  };
  const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

  const BATCH_SIZE = 256;
  const TRAIN_DATA_SIZE = 5500;
  const TEST_DATA_SIZE = 1000;

  const [trainXs, trainYs] = tf.tidy(() => {
    const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
    return [d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]), d.labels];
  });
  // Validation data comes from the test split so that val_loss/val_acc measure
  // generalization rather than fit on images the model is trained on.
  const [testXs, testYs] = tf.tidy(() => {
    const d = data.nextTestBatch(TEST_DATA_SIZE);
    return [d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]), d.labels];
  });

  try {
    return await model.fit(trainXs, trainYs, {
      batchSize: BATCH_SIZE,
      validationData: [testXs, testYs],
      epochs: 10,
      shuffle: true,
      callbacks: fitCallbacks,
    });
  } finally {
    // These tensors are not needed after training and TF.js does not free them
    // automatically; release them even if fit() rejects.
    for (const tensor of [trainXs, trainYs, testXs, testYs]) {
      tensor.dispose();
    }
  }
}
/**
 * Display labels for the ten MNIST classes; entry i names digit i and is used
 * for the per-class accuracy and confusion-matrix views.
 */
const classNames = [
  "Zero",
  "One",
  "Two",
  "Three",
  "Four",
  "Five",
  "Six",
  "Seven",
  "Eight",
  "Nine",
];
/**
 * Runs the model on a fresh batch of test images.
 *
 * @param {tf.LayersModel} model - trained model.
 * @param {MnistData} data - loaded MNIST data source.
 * @param {number} [testDataSize=500] - number of test examples to draw.
 * @returns {[tf.Tensor, tf.Tensor]} [predicted class indices, true class
 *   indices], each the argMax over the last axis. The caller owns both tensors
 *   and must dispose them.
 */
function doPrediction(model, data, testDataSize = 500) {
  const IMAGE_WIDTH = 28;
  const IMAGE_HEIGHT = 28;
  // tidy() frees every intermediate (raw batch tensors, reshaped input, model
  // output) and keeps only the two tensors returned from the callback.
  return tf.tidy(() => {
    const testData = data.nextTestBatch(testDataSize);
    const testxs = testData.xs.reshape([testDataSize, IMAGE_HEIGHT, IMAGE_WIDTH, 1]);
    const labels = testData.labels.argMax(-1);
    const preds = model.predict(testxs).argMax(-1);
    return [preds, labels];
  });
}
/**
 * Evaluates the model on a test batch and renders a per-class accuracy table
 * in the tfjs-vis "Evaluation" tab.
 *
 * @param {tf.LayersModel} model - trained model.
 * @param {MnistData} data - loaded MNIST data source.
 */
async function showAccuracy(model, data) {
  const [preds, labels] = doPrediction(model, data);
  try {
    const classAccuracy = await tfvis.metrics.perClassAccuracy(labels, preds);
    const container = {
      name: "Accuracy",
      tab: "Evaluation",
    };
    tfvis.show.perClassAccuracy(container, classAccuracy, classNames);
  } finally {
    // doPrediction hands ownership of both tensors to the caller.
    labels.dispose();
    preds.dispose();
  }
}
/**
 * Evaluates the model on a test batch and renders its confusion matrix in the
 * tfjs-vis "Evaluation" tab, with class names as the axis tick labels.
 *
 * @param {tf.LayersModel} model - trained model.
 * @param {MnistData} data - loaded MNIST data source.
 */
async function showConfusion(model, data) {
  const [preds, labels] = doPrediction(model, data);
  try {
    const confusionMatrix = await tfvis.metrics.confusionMatrix(labels, preds);
    const container = {
      name: "Confusion Matrix",
      tab: "Evaluation",
    };
    // Axis labels belong in the data argument's `tickLabels`; the third
    // argument of render.confusionMatrix is a rendering-options object.
    tfvis.render.confusionMatrix(container, {
      values: confusionMatrix,
      tickLabels: classNames,
    });
  } finally {
    // doPrediction hands ownership of both tensors to the caller.
    labels.dispose();
    preds.dispose();
  }
}
import {MnistData} from "./data.js"; | |
/**
 * Draws the images of a 20-example test batch as 28x28 canvases in the
 * tfjs-vis "Input Data" tab.
 *
 * @param {MnistData} data - loaded MNIST data source.
 */
async function showExamples(data) {
  // Create a container in the visor
  const surface = tfvis
    .visor()
    .surface({ name: 'Input Data Examples', tab: 'Input Data' });

  // Get the examples
  const examples = data.nextTestBatch(20);
  const numExamples = examples.xs.shape[0];

  try {
    // Create a canvas element to render each example
    for (let i = 0; i < numExamples; i++) {
      // Take row i and reshape the flat pixel vector to a 28x28 px image.
      const imageTensor = tf.tidy(() =>
        examples.xs
          .slice([i, 0], [1, examples.xs.shape[1]])
          .reshape([28, 28, 1])
      );

      const canvas = document.createElement('canvas');
      canvas.width = 28;
      canvas.height = 28;
      canvas.style.margin = '4px';

      try {
        await tf.browser.toPixels(imageTensor, canvas);
        surface.drawArea.appendChild(canvas);
      } finally {
        imageTensor.dispose();
      }
    }
  } finally {
    // The preview batch is only needed here; free it once drawing finishes.
    examples.xs.dispose();
    examples.labels.dispose();
  }
}
/**
 * Application entry point. Loads MNIST and previews sample images. Then builds
 * the model and shows its summary, trains it, and renders the accuracy and
 * confusion-matrix views, each step awaited before the next.
 */
async function run() {
  const data = new MnistData();
  await data.load();
  await showExamples(data);

  const model = getModel();
  const summaryContainer = { name: "Model Architecture" };
  tfvis.show.modelSummary(summaryContainer, model);

  await train(model, data);

  // Evaluation views run sequentially: accuracy first, then the confusion matrix.
  for (const evaluate of [showAccuracy, showConfusion]) {
    await evaluate(model, data);
  }
}
// Start the app once the document is parsed. Module scripts are deferred, so
// DOMContentLoaded has normally not fired yet. If the module is evaluated
// later (e.g. imported dynamically after load), the event will never fire, so
// run immediately instead.
if (document.readyState === "loading") {
  document.addEventListener("DOMContentLoaded", run);
} else {
  // Like the listener path, a rejection surfaces as an unhandled rejection.
  void run();
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment