Last active
May 3, 2020 19:03
-
-
Save frankbryce/ec94377fcac54e4fcf03834551c906ae to your computer and use it in GitHub Desktop.
MinecraftHUD20200502.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "MinecraftHUD20200502.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyNxzn1ixg17aYKElKr69+ck", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "TPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/frankbryce/ec94377fcac54e4fcf03834551c906ae/minecrafthud.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Z2x9vZNZgcuc", | |
"colab_type": "code", | |
"outputId": "082a50dd-f5bb-4afb-eda7-506d058edf83", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 36 | |
} | |
}, | |
"source": [ | |
# Make the user's Google Drive visible to this Colab runtime; the dataset
# images, labels and saved models all live there.
from google.colab import drive

mount_location = '/content/drive/'
drive.mount(mount_location, force_remount=True)

# Folder (relative to the mount point) holding the images to be labeled.
directory = 'My Drive/Data/Minecraft/HUD'  # CHANGE IF NECESSARY
# Absolute folder path with a trailing slash, so file names can be appended.
path_prefix = '{}{}/'.format(mount_location, directory)
], | |
"execution_count": 38, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Mounted at /content/drive/\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "L4LXu4Qft_Ln", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Create helper methods for accessing and creating label and image files\n", | |
"import os\n", | |
"import re\n", | |
"\n", | |
# Extensions of the image / label files that make up the dataset.
img_ext = ".png"
lbl_ext = ".txt"
# File-name patterns; group 1 captures the numeric index shared by an image
# and its label. The patterns are built from the extensions above so the two
# cannot drift apart, and are anchored at the end of the name so stray files
# such as "image001.png.bak" are ignored (they would otherwise produce a
# second example with the same index "001", i.e. a duplicate key).
img_regex = re.compile(r'.*image([0-9]+)' + re.escape(img_ext) + r'$')
lbl_regex = re.compile(r'.*label([0-9]+)' + re.escape(lbl_ext) + r'$')


def getFileIdx(f):
  """Returns the numeric index (as a string) embedded in an image/label name.

  Example: getFileIdx("image001.png") returns "001".

  Args:
    f: file name (or path) ending in imageNNN.png or labelNNN.txt.

  Returns:
    The digits NNN, as a string (leading zeros preserved).

  Raises:
    AssertionError: if `f` is not a str.
    ValueError: if `f` matches neither file-name pattern.
  """
  assert isinstance(f, str)
  m = img_regex.match(f) or lbl_regex.match(f)
  if not m:
    raise ValueError("{filename} didn't match {img_regex} or {lbl_regex}".format(
        img_regex=img_regex.pattern,
        lbl_regex=lbl_regex.pattern,
        filename=f))
  # Group 1 is `[0-9]+`, so it is always a non-empty string here. The former
  # "no index found" branch was unreachable, and would itself have crashed
  # because Exception() does not accept keyword arguments.
  return m.group(1)


def makeLabelFilename(idx):
  """Returns the label file name for an index string.

  Example: makeLabelFilename("001") returns "label001.txt".
  """
  assert isinstance(idx, str)
  return 'label{idx}'.format(idx=idx) + lbl_ext


def makeImageFilename(idx):
  """Returns the image file name for an index string.

  Example: makeImageFilename("001") returns "image001.png".
  """
  assert isinstance(idx, str)
  return 'image{idx}'.format(idx=idx) + img_ext
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "WvACfV7TBFEd", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import tensorflow as tf\n", | |
"import tensorflow_datasets.public_api as tfds\n", | |
"\n", | |
class MinecraftHUD(tfds.core.GeneratorBasedBuilder):
  """Dataset of the Minecraft overlay (HUD) icons.

  Examples are read from the Drive folder `path_prefix`: every imageNNN.png in
  its train/ or test/ sub-folder is paired with labelNNN.txt from the same
  folder, whose first line is the class name.
  """

  VERSION = tfds.core.Version('0.1.1')

  @staticmethod
  def _read_label(lbl_path):
    """Returns the class name on the first line of the label file `lbl_path`."""
    with open(lbl_path) as lbl_file:
      # readline() keeps the trailing newline when there is one; strip it so
      # files with and without a final newline map to the same class and the
      # names are clean when displayed.
      return lbl_file.readline().strip()

  def _info(self):
    # Collect every class name seen in a label file. The top-level folder is
    # scanned as before; the split folders read by _generate_examples are
    # scanned too (when present) so every label an example carries is a
    # known class.
    lbl_dirs = [path_prefix] + [
        os.path.join(path_prefix, split_dir)
        for split_dir in ("train", "test")
        if os.path.isdir(os.path.join(path_prefix, split_dir))
    ]
    lbls = set()
    for lbl_dir in lbl_dirs:
      for f in os.listdir(lbl_dir):
        if lbl_regex.match(f):
          lbls.add(self._read_label(os.path.join(lbl_dir, f)))
    return tfds.core.DatasetInfo(
        builder=self,
        description=("This is the dataset for Minecraft Icons on the overlay "
                     "of the player's screen. It currently contains empty, "
                     "half-full and full heart and hunger icons. More still to "
                     "come (experience level & progress to next level, item "
                     "icons, F3 screen readouts, etc.). This is only available "
                     "to you if you go to Google Drive, add "
                     "https://drive.google.com/drive/folders/1-AuVxiOW40G_7qeANTuVtkIqWuDrVg5T?usp=drive_open "
                     "to your drive, and modify the path in your own colab."),
        features=tfds.features.FeaturesDict({
            "image": tfds.features.Image(dtype=tf.uint16,
                                         encoding_format='png'),
            # Sorted so the name -> integer id mapping is identical in every
            # session; iteration order of a set of str is not (hash
            # randomization), which would silently scramble the meaning of
            # a saved model's outputs.
            "label": tfds.features.ClassLabel(names=sorted(lbls)),
        }),
        supervised_keys=("image", "label"),
        homepage="https://robotproctor.wordpress.com/2020/05/03/minecraft-hud-dataset/",
        # Bibtex citation for the dataset
        citation=r"""@article{minecraft-hud-dataset-2020,
        author = {Carpenter, John},
        year = {2020},
}""",
    )

  def _split_generators(self, dl_manager):
    # Each split is a sub-folder of the Drive folder; nothing is downloaded.
    return [
        tfds.core.SplitGenerator(
            name=tfds.Split.TRAIN,
            gen_kwargs={
                "dir_path": os.path.join(path_prefix, "train"),
            },
        ),
        tfds.core.SplitGenerator(
            name=tfds.Split.TEST,
            gen_kwargs={
                "dir_path": os.path.join(path_prefix, "test"),
            },
        ),
    ]

  def _generate_examples(self, dir_path):
    """Yields (key, example) pairs for every imageNNN.png in `dir_path`.

    The key is the image's index string; the label is read from the matching
    labelNNN.txt in the same directory.
    """
    for f in os.listdir(dir_path):
      if not img_regex.match(f):
        continue
      idx = getFileIdx(f)
      lbl_path = os.path.join(dir_path, makeLabelFilename(idx))
      yield idx, {
          # The Image feature is given the file path and loads the PNG itself.
          "image": os.path.join(dir_path, f),
          "label": self._read_label(lbl_path),
      }
"\n", | |
# Defining MinecraftHUD above is what registers it with TFDS as
# "minecraft_hud"; one builder instance is all that is needed. (Previously a
# second, unused instance was also created; constructing a builder may call
# _info(), which scans the Drive folder.)
minecraft_hud_builder = MinecraftHUD()
# Old name kept bound for any code that still refers to it.
minecraft_hud_load_to_globals = minecraft_hud_builder
minecraft_hud_info = minecraft_hud_builder.info
minecraft_hud_builder.download_and_prepare()
datasets = minecraft_hud_builder.as_dataset()
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "3G56c4x12qSO", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
import random

# Load both splits of the dataset registered above as (image, label) pairs;
# `info` supplies the class names used by later cells.
(raw_train, raw_test), info = tfds.load(
    "minecraft_hud",
    split=['train', 'test'],
    with_info=True, as_supervised=True)

def format_example(image, label):
  """Scales an integer image to floats in [0, 1] and crops/pads it to 32x32.

  The divisor is the max value of the image's integer dtype: 65535 for the
  uint16 images this dataset declares (same as the previous hard-coded
  constant), and still correct if the feature dtype is ever changed.
  """
  max_value = float(image.dtype.max)
  image = tf.cast(image, tf.float32)
  image = image / max_value
  image = tf.image.resize_with_crop_or_pad(image, 32, 32)
  return image, label

train = raw_train.map(format_example)
test = raw_test.map(format_example)

# Seed for the training-set shuffle. None draws a fresh order each run, which
# is what the previous `random.seed()` call seemed to intend -- but Python's
# `random` module never influenced tf.data shuffling. Set an int to make the
# training order reproducible.
shuffle_seed = None
train_batches = train.shuffle(1000, seed=shuffle_seed).batch(32)
test_batches = test.batch(32)
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "RBBk3Ner4XMX", | |
"colab_type": "code", | |
"colab": {}, | |
"cellView": "form" | |
}, | |
"source": [ | |
#@title Load Model from Save
from tensorflow.keras import layers, models

load_model_from_file = False #@param {type:"boolean"}
load_file_name = 'minecraft_hud_model' #@param {type:"string"}

if not load_model_from_file:
  # Fresh CNN: three 3x3 conv stages (the first two followed by 2x2 max
  # pooling), then a small dense head whose last layer has no activation and
  # emits one raw score per class.
  model = models.Sequential([
      layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
      layers.MaxPooling2D((2, 2)),
      layers.Conv2D(64, (3, 3), activation='relu'),
      layers.MaxPooling2D((2, 2)),
      layers.Conv2D(64, (3, 3), activation='relu'),
      layers.Flatten(),
      layers.Dense(64, activation='relu'),
      layers.Dense(info.features['label'].num_classes),
  ])
else:
  # Resume from a model saved into the Drive folder.
  model = models.load_model(path_prefix + load_file_name)
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "SxYO8Mqt4ptG", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 243 | |
}, | |
"cellView": "both", | |
"outputId": "7b6892be-8fc8-4e6a-83c3-6bd071d95e38" | |
}, | |
"source": [ | |
#@title Train the Model

learning_rate = 0.0001 #@param {type:"number"}
num_epochs = 5 #@param {type:"integer"}
save_model_to_file = True #@param {type:"boolean"}
save_file_name = 'minecraft_hud_model' #@param {type:"string"}

# The loss is told the model outputs raw logits (from_logits=True).
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])

# Validate on the test split after every epoch.
history = model.fit(
    train_batches,
    validation_data=test_batches,
    epochs=num_epochs)

# Persist the trained model to Drive so a later session can reload it.
if save_model_to_file:
  model.save(path_prefix + save_file_name)
], | |
"execution_count": 112, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Epoch 1/5\n", | |
"5/5 [==============================] - 1s 123ms/step - loss: 0.0149 - accuracy: 0.9923 - val_loss: 0.0099 - val_accuracy: 1.0000\n", | |
"Epoch 2/5\n", | |
"5/5 [==============================] - 1s 105ms/step - loss: 0.0132 - accuracy: 0.9923 - val_loss: 0.0063 - val_accuracy: 1.0000\n", | |
"Epoch 3/5\n", | |
"5/5 [==============================] - 1s 113ms/step - loss: 0.0052 - accuracy: 1.0000 - val_loss: 0.0047 - val_accuracy: 1.0000\n", | |
"Epoch 4/5\n", | |
"5/5 [==============================] - 1s 111ms/step - loss: 0.0043 - accuracy: 1.0000 - val_loss: 0.0042 - val_accuracy: 1.0000\n", | |
"Epoch 5/5\n", | |
"5/5 [==============================] - 1s 110ms/step - loss: 0.0037 - accuracy: 1.0000 - val_loss: 0.0033 - val_accuracy: 1.0000\n", | |
"INFO:tensorflow:Assets written to: /content/drive/My Drive/Data/Minecraft/HUD/minecraft_hud_model/assets\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Assets written to: /content/drive/My Drive/Data/Minecraft/HUD/minecraft_hud_model/assets\n" | |
], | |
"name": "stderr" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "EWoFWhEQoYbo", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 36 | |
}, | |
"outputId": "604bc445-f514-4343-8911-389bde2661aa" | |
}, | |
"source": [ | |
import tensorflow as tf
import matplotlib.pyplot as plt

label_names = info.features['label'].names

# Run the model over the test split one example at a time, recording the true
# label, the predicted label, and every misclassified image.
lbls = []
predictions = []
bad_preds = []
for (img, lbl) in test.batch(1):
  logits = model(img, training=False)
  prediction = int(tf.argmax(logits, axis=1, output_type=tf.int32)[0])
  # Batches hold a single example: unwrap the label so comparisons and the
  # lists below hold plain ints rather than shape-(1,) tensors.
  true_lbl = int(lbl[0])
  if prediction != true_lbl:
    bad_preds.append((img[0], true_lbl, prediction))
  predictions.append(prediction)
  lbls.append(true_lbl)

# Previously the matrix was computed and discarded (it was not the cell's last
# expression), so it was never shown. num_classes keeps it full-size even when
# some class is absent from the test split.
confusion = tf.math.confusion_matrix(
    lbls, predictions, num_classes=len(label_names))
print('Confusion matrix (rows: true label, columns: prediction):')
print(confusion.numpy())

# Show each misclassified image with its true and predicted class; skip the
# figure entirely when there is nothing to show (it rendered empty before).
if bad_preds:
  fig = plt.figure()
  for i, (bad_img, lbl_id, pred_id) in enumerate(bad_preds):
    fig.add_subplot(1, len(bad_preds), i + 1)
    plt.imshow(bad_img)
    plt.title("lbl: '{}'\npred: '{}'".format(
        label_names[lbl_id],
        label_names[pred_id]))
    plt.axis('off')
else:
  print('No misclassified test examples.')
], | |
"execution_count": 114, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 0 Axes>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment