Skip to content

Instantly share code, notes, and snippets.

@frankbryce
Last active May 3, 2020 19:03
Show Gist options
  • Save frankbryce/ec94377fcac54e4fcf03834551c906ae to your computer and use it in GitHub Desktop.
Save frankbryce/ec94377fcac54e4fcf03834551c906ae to your computer and use it in GitHub Desktop.
MinecraftHUD20200502.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "MinecraftHUD20200502.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyNxzn1ixg17aYKElKr69+ck",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "TPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/frankbryce/ec94377fcac54e4fcf03834551c906ae/minecrafthud.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Z2x9vZNZgcuc",
"colab_type": "code",
"outputId": "082a50dd-f5bb-4afb-eda7-506d058edf83",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 36
}
},
"source": [
"# Mount Google Drive in Colab\n",
"from google.colab import drive\n",
"mount_location = '/content/drive/'\n",
"drive.mount(mount_location, force_remount=True)\n",
"\n",
"# DIRECTORY CONTAINING IMAGES TO BE LABELED\n",
"directory = 'My Drive/Data/Minecraft/HUD' # CHANGE IF NECESSARY\n",
"path_prefix = mount_location + directory + '/'"
],
"execution_count": 38,
"outputs": [
{
"output_type": "stream",
"text": [
"Mounted at /content/drive/\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "L4LXu4Qft_Ln",
"colab_type": "code",
"colab": {}
},
"source": [
"# Create helper methods for accessing and creating label and image files\n",
"import os\n",
"import re\n",
"\n",
"# File extensions and filename patterns used to recognize image and label files\n",
"img_ext = \".png\"\n",
"lbl_ext = \".txt\"\n",
"img_regex = re.compile(r'.*image([0-9]+)\\.png')\n",
"lbl_regex = re.compile(r'.*label([0-9]+)\\.txt')\n",
"\n",
"def getFileIdx(f):\n",
" # getFileIdx returns the index of an image or label file (in string form)\n",
" #\n",
" # Example: getFileIdx(\"image001.png\") returns \"001\"\n",
" assert(isinstance(f, str))\n",
" m = img_regex.match(f)\n",
" if not m:\n",
" m = lbl_regex.match(f)\n",
" if not m:\n",
" raise Exception(\"{filename} didn't match {img_regex} or {lbl_regex}\".format(\n",
" img_regex=img_regex,\n",
" lbl_regex=lbl_regex,\n",
" filename=f))\n",
" idx = str(m.group(1))\n",
" if not idx:\n",
" raise Exception('No index found before file extension for {filename}',\n",
" filename=f)\n",
" return idx\n",
"\n",
"def makeLabelFilename(idx):\n",
" # makeLabelFilename returns the label filename for a given index string.\n",
" # \n",
" # Example: makeLabelFilename(\"001\") returns \"label001.txt\"\n",
" assert(isinstance(idx, str))\n",
" return 'label{idx}'.format(idx=idx) + lbl_ext\n",
"\n",
"def makeImageFilename(idx):\n",
" # makeImageFilename returns the image filename for a given index string.\n",
" # \n",
" # Example: makeImageFilename(\"001\") returns \"image001.png\"\n",
" assert(isinstance(idx, str))\n",
" return 'image{idx}'.format(idx=idx) + img_ext"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "WvACfV7TBFEd",
"colab_type": "code",
"colab": {}
},
"source": [
"import tensorflow as tf\n",
"import tensorflow_datasets.public_api as tfds\n",
"\n",
"class MinecraftHUD(tfds.core.GeneratorBasedBuilder):\n",
" \"\"\"Dataset for the Minecraft overlay icons.\"\"\"\n",
"\n",
" VERSION = tfds.core.Version('0.1.1')\n",
"\n",
" def _info(self):\n",
" lbls = set()\n",
" for f in os.listdir(path_prefix):\n",
" if lbl_regex.match(f):\n",
" with open(os.path.join(path_prefix, f)) as lbl_file:\n",
" lbls.add(lbl_file.readline())\n",
" return tfds.core.DatasetInfo(\n",
" builder=self,\n",
" description=(\"This is the dataset for Minecraft Icons on the overlay \"\n",
" \"of the player's screen. It currently contains empty, \"\n",
" \"half-full and full heart and hunger icons. More still to \"\n",
" \"come (experience level & progress to next level, item \"\n",
" \"icons, F3 screen readouts, etc.). This is only to you if \"\n",
" \"you go to Google Drive, add \"\n",
" \"https://drive.google.com/drive/folders/1-AuVxiOW40G_7qeANTuVtkIqWuDrVg5T?usp=drive_open \"\n",
" \"to your drive, and modify the path in your own colab.\"),\n",
" features=tfds.features.FeaturesDict({\n",
" \"image\": tfds.features.Image(dtype=tf.uint16,\n",
" encoding_format='png'),\n",
" \"label\": tfds.features.ClassLabel(names=list(lbls)),\n",
" }),\n",
" supervised_keys=(\"image\", \"label\"),\n",
" homepage=\"https://robotproctor.wordpress.com/2020/05/03/minecraft-hud-dataset/\",\n",
" # Bibtex citation for the dataset\n",
" citation=r\"\"\"@article{minecraft-hud-dataset-2020,\n",
" author = {Carpenter, John},\"}\"\"\",\n",
" )\n",
"\n",
" def _split_generators(self, dl_manager):\n",
" return [\n",
" tfds.core.SplitGenerator(\n",
" name=tfds.Split.TRAIN,\n",
" gen_kwargs={\n",
" \"dir_path\": os.path.join(path_prefix, \"train\"),\n",
" },\n",
" ),\n",
" tfds.core.SplitGenerator(\n",
" name=tfds.Split.TEST,\n",
" gen_kwargs={\n",
" \"dir_path\": os.path.join(path_prefix, \"test\"),\n",
" },\n",
" ),\n",
" ]\n",
"\n",
" def _generate_examples(self, dir_path):\n",
" for f in os.listdir(dir_path):\n",
" if img_regex.match(f):\n",
" idx = getFileIdx(f)\n",
" lbl_filename = makeLabelFilename(idx)\n",
" with open(os.path.join(dir_path, lbl_filename)) as lbl_file:\n",
" lbl = lbl_file.readline()\n",
" yield idx, {\n",
" \"image\": str(os.path.join(dir_path, f)),\n",
" \"label\": lbl,\n",
" }\n",
"\n",
"minecraft_hud_load_to_globals = MinecraftHUD()\n",
"minecraft_hud_builder = tfds.builder(\"minecraft_hud\")\n",
"minecraft_hud_info = minecraft_hud_builder.info\n",
"minecraft_hud_builder.download_and_prepare()\n",
"datasets = minecraft_hud_builder.as_dataset()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3G56c4x12qSO",
"colab_type": "code",
"colab": {}
},
"source": [
"import random\n",
"\n",
"(raw_train, raw_test), info = tfds.load(\n",
" \"minecraft_hud\", \n",
" split=['train', 'test'],\n",
" with_info=True, as_supervised=True)\n",
"\n",
"def format_example(image, label):\n",
" image = tf.cast(image, tf.float32)\n",
" image = (image/65535.0)\n",
" image = tf.image.resize_with_crop_or_pad(image, 32, 32)\n",
" return image, label\n",
"\n",
"train = raw_train.map(format_example)\n",
"test = raw_test.map(format_example)\n",
"\n",
"random.seed()\n",
"train_batches = train.shuffle(1000).batch(32)\n",
"test_batches = test.batch(32)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "RBBk3Ner4XMX",
"colab_type": "code",
"colab": {},
"cellView": "form"
},
"source": [
"#@title Load Model from Save\n",
"from tensorflow.keras import layers, models\n",
"\n",
"load_model_from_file = False #@param {type:\"boolean\"}\n",
"load_file_name = 'minecraft_hud_model' #@param {type:\"string\"}\n",
"\n",
"if load_model_from_file:\n",
" model = models.load_model(path_prefix + load_file_name)\n",
"else:\n",
" model = models.Sequential()\n",
" model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))\n",
" model.add(layers.MaxPooling2D((2, 2)))\n",
" model.add(layers.Conv2D(64, (3, 3), activation='relu'))\n",
" model.add(layers.MaxPooling2D((2, 2)))\n",
" model.add(layers.Conv2D(64, (3, 3), activation='relu'))\n",
"\n",
" model.add(layers.Flatten())\n",
" model.add(layers.Dense(64, activation='relu'))\n",
" model.add(layers.Dense(info.features['label'].num_classes))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "SxYO8Mqt4ptG",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 243
},
"cellView": "both",
"outputId": "7b6892be-8fc8-4e6a-83c3-6bd071d95e38"
},
"source": [
"#@title Train the Model\n",
"\n",
"learning_rate = 0.0001 #@param {type:\"number\"}\n",
"num_epochs = 5 #@param {type:\"integer\"}\n",
"\n",
"model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),\n",
" loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
" metrics=['accuracy'])\n",
"history = model.fit(train_batches,\n",
" epochs=num_epochs,\n",
" validation_data=test_batches)\n",
" \n",
"save_model_to_file = True #@param {type:\"boolean\"}\n",
"save_file_name = 'minecraft_hud_model' #@param {type:\"string\"}\n",
"\n",
"if save_model_to_file:\n",
" model.save(path_prefix + save_file_name) "
],
"execution_count": 112,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"5/5 [==============================] - 1s 123ms/step - loss: 0.0149 - accuracy: 0.9923 - val_loss: 0.0099 - val_accuracy: 1.0000\n",
"Epoch 2/5\n",
"5/5 [==============================] - 1s 105ms/step - loss: 0.0132 - accuracy: 0.9923 - val_loss: 0.0063 - val_accuracy: 1.0000\n",
"Epoch 3/5\n",
"5/5 [==============================] - 1s 113ms/step - loss: 0.0052 - accuracy: 1.0000 - val_loss: 0.0047 - val_accuracy: 1.0000\n",
"Epoch 4/5\n",
"5/5 [==============================] - 1s 111ms/step - loss: 0.0043 - accuracy: 1.0000 - val_loss: 0.0042 - val_accuracy: 1.0000\n",
"Epoch 5/5\n",
"5/5 [==============================] - 1s 110ms/step - loss: 0.0037 - accuracy: 1.0000 - val_loss: 0.0033 - val_accuracy: 1.0000\n",
"INFO:tensorflow:Assets written to: /content/drive/My Drive/Data/Minecraft/HUD/minecraft_hud_model/assets\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: /content/drive/My Drive/Data/Minecraft/HUD/minecraft_hud_model/assets\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "EWoFWhEQoYbo",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 36
},
"outputId": "604bc445-f514-4343-8911-389bde2661aa"
},
"source": [
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"\n",
"lbls = []\n",
"predictions = []\n",
"bad_preds = []\n",
"for (img, lbl) in test.batch(1):\n",
" logits = model(img, training=False)\n",
" prediction = int(tf.argmax(logits, axis=1, output_type=tf.int32)[0])\n",
" if prediction != lbl:\n",
" bad_preds.append((img[0], lbl, prediction))\n",
" predictions.append(prediction)\n",
" lbls.append(lbl)\n",
"\n",
"tf.math.confusion_matrix(lbls, predictions)\n",
"\n",
"fig = plt.figure()\n",
"for i in range(len(bad_preds)):\n",
" fig.add_subplot(1,len(bad_preds),i+1)\n",
" plt.imshow(bad_preds[i][0])\n",
" plt.title(\"lbl: '{}'\\npred: '{}'\".format(\n",
" info.features['label'].names[int(bad_preds[i][1])],\n",
" info.features['label'].names[int(bad_preds[i][2])]))\n",
" plt.axis('off')"
],
"execution_count": 114,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 0 Axes>"
]
},
"metadata": {
"tags": []
}
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment