Skip to content

Instantly share code, notes, and snippets.

@rubenhorn
Last active August 16, 2020 17:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rubenhorn/677b0ccafc635d17d688e730000ee381 to your computer and use it in GitHub Desktop.
Save rubenhorn/677b0ccafc635d17d688e730000ee381 to your computer and use it in GitHub Desktop.
Neural network digit recognition example with GUI
# Based on https://www.tensorflow.org/datasets/keras_example
# Requires numpy, pygame, tensorflow and tensorflow-datasets
# Number of passes over the training set; passed as `epochs` to model.fit below.
NUMBER_OF_EPOCHS = 8
import numpy as np
import pygame
# pygame is initialised first so the window can show progress messages
# while the slower imports and training steps run further down.
pygame.init()
pygame.display.set_caption('MNIST digits')
# Side length of the drawing canvas in cells; also used for the model input shape.
IMAGE_SIZE = 28
# Edge length, in screen pixels, of the square drawn for one image cell.
PIXEL_SIZE = 10
# Vertical spacing, in screen pixels, between consecutive text lines.
LINE_HEIGHT = 25
# Font shared by every draw_text call (name None -> pygame's default font, size 28).
font = pygame.font.SysFont(None, 28)
def draw_text(text, x, y):
    """Draw `text` in black at window position (x, y) and refresh the display.

    The text is rendered with the module-level `font` (antialiased) and
    blitted onto the module-level `screen`. Nothing underneath is erased
    first, so callers clear the target area themselves when overwriting.
    """
    black = (0, 0, 0)
    rendered = font.render(text, True, black)
    screen.blit(rendered, (x, y))
    # Push the change to the window right away so progress messages appear
    # even while the script is busy with long-running steps.
    pygame.display.update()
# draw window
# Width covers the canvas; height covers the canvas plus three text lines
# (prediction line and two help lines) underneath it.
window_size = (IMAGE_SIZE * PIXEL_SIZE, IMAGE_SIZE * PIXEL_SIZE + LINE_HEIGHT * 3)
screen = pygame.display.set_mode(window_size)
# Mid-grey background; unpredict() repaints with the same colour to erase text.
screen.fill((125, 125, 125))
# The tensorflow imports are deliberately placed after the window exists so
# the progress message is visible while they load.
draw_text('1. Importing tensorflow...', 5, 5)
import tensorflow as tf # requires tf2
import tensorflow_datasets as tfds
# load dataset
draw_text('2. Importing mnist...', 5, 5 + LINE_HEIGHT)
(train, test), info = tfds.load(
'mnist',
split = ['train', 'test'],
shuffle_files = True,
# as_supervised: elements are (image, label) pairs, which normalize_img relies on
as_supervised = True,
# with_info: `info` is used later for the training-split example count
with_info = True
)
# Take one element from the (still un-normalised) test split and keep its
# image part; it is the initial canvas content and is edited in place later.
# NOTE(review): presumably a (28, 28, 1) uint8 array, given the image[y, x, 0]
# indexing and the value 255 written by the event loop - confirm.
image = list(test.take(1).as_numpy_iterator())[0][0]
def draw_image(image):
    """Paint `image` onto the canvas area of the window and refresh the display.

    Each cell image[row, col, 0] becomes a PIXEL_SIZE x PIXEL_SIZE square
    whose red, green and blue components all equal the cell value, so the
    image is shown in shades of grey.
    """
    row_count, col_count = image.shape[0], image.shape[1]
    for col in range(col_count):
        left = col * PIXEL_SIZE
        for row in range(row_count):
            level = image[row, col, 0]
            cell = (left, row * PIXEL_SIZE, PIXEL_SIZE, PIXEL_SIZE)
            pygame.draw.rect(screen, (level, level, level), cell)
    # One refresh after all cells are drawn.
    pygame.display.update()
draw_text('3. Preparing dataset...', 5, 5 + LINE_HEIGHT * 2)


def normalize_img(image, label):
    """Cast `image` to float32 and divide by 255; `label` is returned unchanged.

    For uint8 input this maps pixel values into the range [0, 1].
    """
    return (tf.cast(image, tf.float32) / 255.0, label)


# Training pipeline: normalise, cache, shuffle with a buffer as large as the
# whole training split, batch, then prefetch.
train = (
    train
    .map(normalize_img, num_parallel_calls = tf.data.experimental.AUTOTUNE)
    .cache()
    .shuffle(info.splits['train'].num_examples)
    .batch(128)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Evaluation pipeline: normalise, batch, cache the batches, then prefetch.
# There is no shuffle step here.
test = (
    test
    .map(normalize_img, num_parallel_calls = tf.data.experimental.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.experimental.AUTOTUNE)
)
# define model: flatten -> Dense(128, relu) -> Dense(10, softmax)
inputs = tf.keras.Input(shape = (IMAGE_SIZE, IMAGE_SIZE, 1))
flatten = tf.keras.layers.Flatten()(inputs)
hidden = tf.keras.layers.Dense(128, activation = 'relu')(flatten)
# The output layer must consume `hidden`. Wiring it to `flatten` (as before)
# leaves the 128-unit layer disconnected from the model, reducing the network
# to a single softmax layer over the raw pixels.
outputs = tf.keras.layers.Dense(10, activation = 'softmax')(hidden)
model = tf.keras.Model(inputs = inputs, outputs = outputs)
# Labels are integer class ids, hence the sparse categorical loss; the
# softmax output provides the per-class probabilities that predict() reads.
model.compile(
    loss = 'sparse_categorical_crossentropy',
    optimizer = tf.keras.optimizers.Adam(0.001),
    metrics = ['accuracy']
)
# train
draw_text('4. Training model...', 5, 5 + LINE_HEIGHT * 3)
# Training runs to completion here, before the event loop starts, so no
# window events are processed until it returns. The test split is passed
# as validation data.
model.fit(train, epochs = NUMBER_OF_EPOCHS, validation_data = test)
# Help text goes on the second and third text lines below the canvas; the
# first line below the canvas is where predict() writes its result.
draw_text('Press c to clear the canvas', 5, 5 + IMAGE_SIZE * PIXEL_SIZE + LINE_HEIGHT)
draw_text('Press p to update prediction', 5, 5 + IMAGE_SIZE * PIXEL_SIZE + LINE_HEIGHT * 2)
def predict():
    """Classify the current canvas image and print the result below the canvas.

    The module-level `image` is normalised with `normalize_img`, wrapped in
    a batch of one and passed to `model.predict`. The class with the highest
    score and that score as a whole-number percentage (truncated) are drawn
    on the first text line under the canvas. Existing text there is not
    erased; call unpredict() first when replacing an earlier prediction.
    """
    scaled, _ = normalize_img(image, 0)  # the label argument is a dummy
    batch = np.array([scaled])
    scores = model.predict(batch)[0]
    best_class = np.argmax(scores)
    percent = int(scores[best_class] * 100)
    message = 'Predicts {} ({}%)'.format(best_class, percent)
    draw_text(message, 5, 5 + IMAGE_SIZE * PIXEL_SIZE)
def unpredict():
    """Erase the prediction line by repainting it with the background colour.

    Covers the strip directly below the canvas: full canvas width and
    LINE_HEIGHT tall. The display is not refreshed here, so the erase shows
    up with the next display update made by another drawing call.
    """
    canvas_extent = IMAGE_SIZE * PIXEL_SIZE
    background = (125, 125, 125)
    strip = (0, canvas_extent, canvas_extent, LINE_HEIGHT)
    pygame.draw.rect(screen, background, strip)
# Show the sample digit and its prediction before entering the event loop.
draw_image(image)
predict()


def get_selected_pixel():
    """Return the image cell under the mouse cursor as [column, row].

    Each window coordinate is divided by PIXEL_SIZE and truncated, then
    capped at IMAGE_SIZE - 1, so a cursor in the text area below the canvas
    maps to the last row.
    """
    last_index = IMAGE_SIZE - 1
    cell = []
    for coordinate in pygame.mouse.get_pos():
        cell.append(min(int(coordinate / PIXEL_SIZE), last_index))
    return cell
# Main loop: handle window/keyboard events, then paint while the left mouse
# button is held down. Runs until the window is closed.
is_running = True
while is_running:
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            is_running = False
            continue
        if event.type != pygame.KEYDOWN:
            continue
        if event.key == pygame.K_c:
            # Clear: zero the image in place, erase the stale prediction
            # and repaint the now-black canvas.
            image *= 0
            unpredict()
            draw_image(image)
        elif event.key == pygame.K_p:
            # Re-run the classifier on whatever is currently drawn.
            unpredict()
            predict()
    left_button_down = pygame.mouse.get_pressed()[0]
    if left_button_down:
        # Paint the cell under the cursor white; the old prediction no
        # longer matches the drawing, so it is erased as well.
        column, row = get_selected_pixel()
        image[row, column, 0] = 255
        unpredict()
        draw_image(image)
pygame.quit()
exit()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment