-
-
Save shindooo/d9c2ba9cd35720400ec1bdfd531fcabd to your computer and use it in GitHub Desktop.
conversion_to_type
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "conversion_to_type", | |
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/shindooo/d9c2ba9cd35720400ec1bdfd531fcabd/conversion_to_type.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ZWv3im0xDHQs", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"from google.colab import drive \n", | |
"drive.mount('/content/drive')\n", | |
"%cd 'drive/My Drive/qiita'\n", | |
"!ls" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "H4kTTCbNgFAj", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"!pip install emnist" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "88oFW-zStMLO", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"from PIL import Image, ImageDraw, ImageFont\n", | |
"import matplotlib.pyplot as plt\n", | |
"import matplotlib.font_manager as font_manager\n", | |
"import cv2\n", | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"from sklearn.model_selection import train_test_split\n", | |
"from tqdm import tqdm_notebook as tqdm\n", | |
"\n", | |
"from keras.layers import Reshape, Conv2D\n", | |
"from keras.preprocessing.image import ImageDataGenerator\n", | |
"from keras.models import Model, load_model\n", | |
"from keras.applications.resnet50 import ResNet50\n", | |
"import keras.optimizers as optimizers\n", | |
"import keras.losses as losses\n", | |
"import keras.callbacks\n", | |
"\n", | |
"# EMNIST\n", | |
"import emnist\n" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "H9OVJAqkyFzX", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# 各クラスを文字に直したもの\n", | |
"chars = [chr(i) for i in range(48, 48+10)] + [chr(i) for i in range(65, 65+26)] + [chr(i) for i in range(97, 97+26)]\n", | |
"num_classes = len(chars)\n", | |
"print(chars)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "xvC3qVUmu_Io", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"font_manager.findSystemFonts(fontpaths=None, fontext='ttf')" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "W71pLkRp0Xkf", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def get_char_size(char, font):\n", | |
" testImg = Image.new('RGB', (1, 1))\n", | |
" testDraw = ImageDraw.Draw(testImg)\n", | |
" return testDraw.textsize(char, font)\n", | |
"\n", | |
"def get_char_size_max(font):\n", | |
" max_width, max_height = 1, 1\n", | |
" for char in chars:\n", | |
" width, height = get_char_size(char, font)\n", | |
" max_width, max_height = max(width, max_width), max(height, max_height)\n", | |
" return max_width, max_height\n", | |
"\n", | |
"def string_to_img_array(text):\n", | |
" \"\"\"\n", | |
" 文字列を画像に描画しnumpy配列に変換する\n", | |
" \"\"\"\n", | |
"\n", | |
" font_path = '/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf'\n", | |
" fontsize = 28\n", | |
" font = ImageFont.truetype(font_path, fontsize)\n", | |
"\n", | |
" # width, height = get_char_size_max(font)\n", | |
" # print(width, height)\n", | |
" width, height = 32, 32\n", | |
"\n", | |
" colorText = \"black\"\n", | |
" colorBackground = \"white\"\n", | |
" colorOutline = \"white\"\n", | |
"\n", | |
" img = Image.new('L', (width, height), colorBackground)\n", | |
" d = ImageDraw.Draw(img)\n", | |
" d.text((1, 1), text, fill=colorText, font=font)\n", | |
" d.rectangle((0, 0, width, height), outline=colorOutline)\n", | |
" img_array = np.array(img)\n", | |
"\n", | |
" return img_array" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "AvSMACB87zdp", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Keras組み込みのResNet50に入力できるよう便宜的に変換\n", | |
"\n", | |
"def convet_x_for_resnet50(X):\n", | |
" X = X.reshape(X.shape + (1,))\n", | |
" X = np.array([cv2.resize(x, dsize=(32, 32), interpolation=cv2.INTER_CUBIC) for x in X])\n", | |
" return np.stack((X,) * 3, axis=-1)\n", | |
"\n", | |
"def convet_y_for_resnet50(Y):\n", | |
" # 正解画像の作成もここで実行\n", | |
" Y = np.array([string_to_img_array(chars[y]) for y in tqdm(Y)] )\n", | |
" return np.stack((Y,) * 3, axis=-1)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "m6YEmsnbl6wx", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"emnist.list_datasets()" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "r8xBlvFCCcH0", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# EMNISTをロード\n", | |
"train_X, train_Y = emnist.extract_training_samples('byclass')\n", | |
"train_X, train_Y = train_X[:100000], train_Y[:100000]\n", | |
"test_X, test_Y = emnist.extract_test_samples('byclass')\n", | |
"test_X, test_Y = test_X[:10000], test_Y[:10000]" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "pujNpIjGnhPJ", | |
"colab": {} | |
}, | |
"source": [ | |
"# Keras組み込みのResNet50に入力できるよう便宜的に変換(+正規化)\n", | |
"train_X = convet_x_for_resnet50(train_X) / 255\n", | |
"train_Y = convet_y_for_resnet50(train_Y) / 255\n", | |
"test_X = convet_x_for_resnet50(test_X) / 255\n", | |
"test_Y = convet_y_for_resnet50(test_Y) / 255" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "-o_iBHo9CK97", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def reshape_for_sample_show(data):\n", | |
" data = data.reshape(-1,8,32,32,3)\n", | |
" data = data.transpose(0,2,1,3,4)\n", | |
" return data.reshape(8*32,8*32,3)\n", | |
"\n", | |
"sample_X = reshape_for_sample_show(train_X[:64])\n", | |
"sample_Y = reshape_for_sample_show(train_Y[:64])\n", | |
"plt.figure(figsize=(16,8))\n", | |
"plt.subplot(1, 2, 1)\n", | |
"plt.title('Input')\n", | |
"plt.imshow(sample_X)\n", | |
"plt.axis('off')\n", | |
"\n", | |
"plt.subplot(1, 2, 2)\n", | |
"plt.title('Correct')\n", | |
"plt.imshow(sample_Y)\n", | |
"plt.axis('off')" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "SlRLeX0SwG0L", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"#適宜調整\n", | |
"batch_size = 180\n", | |
"epochs = 96#160\n", | |
"lr = 0.0001" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "3hJSlLEteUIP", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"base_model = ResNet50(weights=None, include_top=False, input_shape=train_X.shape[1:])\n", | |
"model = base_model.output\n", | |
"model = Reshape((32, 32, 2))(model)\n", | |
"\n", | |
"# 出力層\n", | |
"model = Conv2D(3, (32, 32), padding='same', activation='linear')(model)\n", | |
"\n", | |
"model = Model(inputs=base_model.input, outputs=model)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "VLH6t7vSSfvP", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"model.load_weights('checkpoint.h5')" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "qlUrmouBvVfY", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"model.compile(loss=losses.mean_squared_error,\n", | |
" optimizer=optimizers.Adam(lr=lr),\n", | |
" metrics=['mae', 'mse'])" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "sT3I5VaKEtI-", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# 必要ならチェックポイント設定\n", | |
"checkpoint = keras.callbacks.ModelCheckpoint(filepath = 'checkpoint.h5', monitor='val_mean_squared_error', verbose=1, save_best_only=True, mode='auto')\n", | |
"cbs = [checkpoint]\n", | |
"\n", | |
"# 訓練\n", | |
"history = model.fit(train_X, train_Y, batch_size=batch_size, epochs=epochs,\n", | |
" verbose=1, validation_split=0.1, callbacks=cbs)\n", | |
"\n", | |
"# model.save('saved_model.h5')" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ZS1xZwE_vaJd", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# 評価\n", | |
"score = model.evaluate(test_X, test_Y, verbose=0)\n", | |
"print('Test loss:', score[0])\n", | |
"print('Test mae:', score[1])" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "fvUV5EKRviZj", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"hist = pd.DataFrame(history.history)\n", | |
"hist['epoch'] = history.epoch\n", | |
"\n", | |
"plt.figure()\n", | |
"\n", | |
"plt.xlabel('Epoch')\n", | |
"plt.ylabel('Mean Abs Error')\n", | |
"plt.plot(hist['epoch'], hist['mean_absolute_error'],\n", | |
" label='Train Error')\n", | |
"plt.plot(hist['epoch'], hist['val_mean_absolute_error'],\n", | |
" label = 'Val Error')\n", | |
"plt.legend()\n", | |
"\n", | |
"plt.figure()\n", | |
"\n", | |
"plt.xlabel('Epoch')\n", | |
"plt.ylabel('Mean Square Error')\n", | |
"plt.plot(hist['epoch'], hist['mean_squared_error'],\n", | |
" label='Train Error')\n", | |
"plt.plot(hist['epoch'], hist['val_mean_squared_error'],\n", | |
" label = 'Val Error')\n", | |
"plt.legend()\n", | |
"plt.show()" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "coHelj6pEQst", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"model.load_weights('checkpoint.h5')" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "nN9b-OOUUJk0", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"predicted = model.predict(test_X[:64])\n", | |
"\n", | |
"sample_input = reshape_for_sample_show(test_X[:64])\n", | |
"sample_correct = reshape_for_sample_show(test_Y[:64])\n", | |
"sample_predicted = reshape_for_sample_show(predicted)\n", | |
"\n", | |
"plt.figure(figsize=(18,6))\n", | |
"\n", | |
"# 入力\n", | |
"plt.subplot(1, 3, 1)\n", | |
"plt.title('Input')\n", | |
"plt.imshow(sample_input)\n", | |
"plt.axis('off')\n", | |
"\n", | |
"# 出力\n", | |
"plt.subplot(1, 3, 2)\n", | |
"plt.title('Predicated')\n", | |
"plt.imshow(sample_predicted)\n", | |
"plt.axis('off')\n", | |
"\n", | |
"# 正解\n", | |
"plt.subplot(1, 3, 3)\n", | |
"plt.title('Correct')\n", | |
"plt.imshow(sample_correct)\n", | |
"plt.axis('off')" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "7N3B6F84r_O-", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"katakana = np.expand_dims(cv2.resize(np.array(Image.open('ア.png')), dsize=(32, 32), interpolation=cv2.INTER_CUBIC), axis=0) / 255\n", | |
"predicted = model.predict(katakana)\n", | |
"plt.figure(figsize=(6,6))\n", | |
"plt.imshow(predicted[0])\n", | |
"plt.axis('off')" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "n3Ox5xFypqr_", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment