Skip to content

Instantly share code, notes, and snippets.

@shindooo
Created March 12, 2020 03:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shindooo/d9c2ba9cd35720400ec1bdfd531fcabd to your computer and use it in GitHub Desktop.
Save shindooo/d9c2ba9cd35720400ec1bdfd531fcabd to your computer and use it in GitHub Desktop.
conversion_to_type
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "conversion_to_type",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/shindooo/d9c2ba9cd35720400ec1bdfd531fcabd/conversion_to_type.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ZWv3im0xDHQs",
"colab_type": "code",
"colab": {}
},
"source": [
"from google.colab import drive \n",
"drive.mount('/content/drive')\n",
"%cd 'drive/My Drive/qiita'\n",
"!ls"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "H4kTTCbNgFAj",
"colab_type": "code",
"colab": {}
},
"source": [
"!pip install emnist"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "88oFW-zStMLO",
"colab_type": "code",
"colab": {}
},
"source": [
"from PIL import Image, ImageDraw, ImageFont\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.font_manager as font_manager\n",
"import cv2\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"from tqdm import tqdm_notebook as tqdm\n",
"\n",
"from keras.layers import Reshape, Conv2D\n",
"from keras.preprocessing.image import ImageDataGenerator\n",
"from keras.models import Model, load_model\n",
"from keras.applications.resnet50 import ResNet50\n",
"import keras.optimizers as optimizers\n",
"import keras.losses as losses\n",
"import keras.callbacks\n",
"\n",
"# EMNIST\n",
"import emnist\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "H9OVJAqkyFzX",
"colab_type": "code",
"colab": {}
},
"source": [
"# 各クラスを文字に直したもの\n",
"chars = [chr(i) for i in range(48, 48+10)] + [chr(i) for i in range(65, 65+26)] + [chr(i) for i in range(97, 97+26)]\n",
"num_classes = len(chars)\n",
"print(chars)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "xvC3qVUmu_Io",
"colab_type": "code",
"colab": {}
},
"source": [
"font_manager.findSystemFonts(fontpaths=None, fontext='ttf')"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "W71pLkRp0Xkf",
"colab_type": "code",
"colab": {}
},
"source": [
"def get_char_size(char, font):\n",
" testImg = Image.new('RGB', (1, 1))\n",
" testDraw = ImageDraw.Draw(testImg)\n",
" return testDraw.textsize(char, font)\n",
"\n",
"def get_char_size_max(font):\n",
" max_width, max_height = 1, 1\n",
" for char in chars:\n",
" width, height = get_char_size(char, font)\n",
" max_width, max_height = max(width, max_width), max(height, max_height)\n",
" return max_width, max_height\n",
"\n",
"def string_to_img_array(text):\n",
" \"\"\"\n",
" 文字列を画像に描画しnumpy配列に変換する\n",
" \"\"\"\n",
"\n",
" font_path = '/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf'\n",
" fontsize = 28\n",
" font = ImageFont.truetype(font_path, fontsize)\n",
"\n",
" # width, height = get_char_size_max(font)\n",
" # print(width, height)\n",
" width, height = 32, 32\n",
"\n",
" colorText = \"black\"\n",
" colorBackground = \"white\"\n",
" colorOutline = \"white\"\n",
"\n",
" img = Image.new('L', (width, height), colorBackground)\n",
" d = ImageDraw.Draw(img)\n",
" d.text((1, 1), text, fill=colorText, font=font)\n",
" d.rectangle((0, 0, width, height), outline=colorOutline)\n",
" img_array = np.array(img)\n",
"\n",
" return img_array"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "AvSMACB87zdp",
"colab_type": "code",
"colab": {}
},
"source": [
"# Keras組み込みのResNet50に入力できるよう便宜的に変換\n",
"\n",
"def convet_x_for_resnet50(X):\n",
" X = X.reshape(X.shape + (1,))\n",
" X = np.array([cv2.resize(x, dsize=(32, 32), interpolation=cv2.INTER_CUBIC) for x in X])\n",
" return np.stack((X,) * 3, axis=-1)\n",
"\n",
"def convet_y_for_resnet50(Y):\n",
" # 正解画像の作成もここで実行\n",
" Y = np.array([string_to_img_array(chars[y]) for y in tqdm(Y)] )\n",
" return np.stack((Y,) * 3, axis=-1)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "m6YEmsnbl6wx",
"colab_type": "code",
"colab": {}
},
"source": [
"emnist.list_datasets()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "r8xBlvFCCcH0",
"colab_type": "code",
"colab": {}
},
"source": [
"# EMNISTをロード\n",
"train_X, train_Y = emnist.extract_training_samples('byclass')\n",
"train_X, train_Y = train_X[:100000], train_Y[:100000]\n",
"test_X, test_Y = emnist.extract_test_samples('byclass')\n",
"test_X, test_Y = test_X[:10000], test_Y[:10000]"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "pujNpIjGnhPJ",
"colab": {}
},
"source": [
"# Keras組み込みのResNet50に入力できるよう便宜的に変換(+正規化)\n",
"train_X = convet_x_for_resnet50(train_X) / 255\n",
"train_Y = convet_y_for_resnet50(train_Y) / 255\n",
"test_X = convet_x_for_resnet50(test_X) / 255\n",
"test_Y = convet_y_for_resnet50(test_Y) / 255"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "-o_iBHo9CK97",
"colab_type": "code",
"colab": {}
},
"source": [
"def reshape_for_sample_show(data):\n",
" data = data.reshape(-1,8,32,32,3)\n",
" data = data.transpose(0,2,1,3,4)\n",
" return data.reshape(8*32,8*32,3)\n",
"\n",
"sample_X = reshape_for_sample_show(train_X[:64])\n",
"sample_Y = reshape_for_sample_show(train_Y[:64])\n",
"plt.figure(figsize=(16,8))\n",
"plt.subplot(1, 2, 1)\n",
"plt.title('Input')\n",
"plt.imshow(sample_X)\n",
"plt.axis('off')\n",
"\n",
"plt.subplot(1, 2, 2)\n",
"plt.title('Correct')\n",
"plt.imshow(sample_Y)\n",
"plt.axis('off')"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "SlRLeX0SwG0L",
"colab_type": "code",
"colab": {}
},
"source": [
"#適宜調整\n",
"batch_size = 180\n",
"epochs = 96#160\n",
"lr = 0.0001"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3hJSlLEteUIP",
"colab_type": "code",
"colab": {}
},
"source": [
"base_model = ResNet50(weights=None, include_top=False, input_shape=train_X.shape[1:])\n",
"model = base_model.output\n",
"model = Reshape((32, 32, 2))(model)\n",
"\n",
"# 出力層\n",
"model = Conv2D(3, (32, 32), padding='same', activation='linear')(model)\n",
"\n",
"model = Model(inputs=base_model.input, outputs=model)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "VLH6t7vSSfvP",
"colab_type": "code",
"colab": {}
},
"source": [
"model.load_weights('checkpoint.h5')"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "qlUrmouBvVfY",
"colab_type": "code",
"colab": {}
},
"source": [
"model.compile(loss=losses.mean_squared_error,\n",
" optimizer=optimizers.Adam(lr=lr),\n",
" metrics=['mae', 'mse'])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "sT3I5VaKEtI-",
"colab_type": "code",
"colab": {}
},
"source": [
"# 必要ならチェックポイント設定\n",
"checkpoint = keras.callbacks.ModelCheckpoint(filepath = 'checkpoint.h5', monitor='val_mean_squared_error', verbose=1, save_best_only=True, mode='auto')\n",
"cbs = [checkpoint]\n",
"\n",
"# 訓練\n",
"history = model.fit(train_X, train_Y, batch_size=batch_size, epochs=epochs,\n",
" verbose=1, validation_split=0.1, callbacks=cbs)\n",
"\n",
"# model.save('saved_model.h5')"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ZS1xZwE_vaJd",
"colab_type": "code",
"colab": {}
},
"source": [
"# 評価\n",
"score = model.evaluate(test_X, test_Y, verbose=0)\n",
"print('Test loss:', score[0])\n",
"print('Test mae:', score[1])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "fvUV5EKRviZj",
"colab_type": "code",
"colab": {}
},
"source": [
"hist = pd.DataFrame(history.history)\n",
"hist['epoch'] = history.epoch\n",
"\n",
"plt.figure()\n",
"\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Mean Abs Error')\n",
"plt.plot(hist['epoch'], hist['mean_absolute_error'],\n",
" label='Train Error')\n",
"plt.plot(hist['epoch'], hist['val_mean_absolute_error'],\n",
" label = 'Val Error')\n",
"plt.legend()\n",
"\n",
"plt.figure()\n",
"\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Mean Square Error')\n",
"plt.plot(hist['epoch'], hist['mean_squared_error'],\n",
" label='Train Error')\n",
"plt.plot(hist['epoch'], hist['val_mean_squared_error'],\n",
" label = 'Val Error')\n",
"plt.legend()\n",
"plt.show()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "coHelj6pEQst",
"colab_type": "code",
"colab": {}
},
"source": [
"model.load_weights('checkpoint.h5')"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "nN9b-OOUUJk0",
"colab_type": "code",
"colab": {}
},
"source": [
"predicted = model.predict(test_X[:64])\n",
"\n",
"sample_input = reshape_for_sample_show(test_X[:64])\n",
"sample_correct = reshape_for_sample_show(test_Y[:64])\n",
"sample_predicted = reshape_for_sample_show(predicted)\n",
"\n",
"plt.figure(figsize=(18,6))\n",
"\n",
"# 入力\n",
"plt.subplot(1, 3, 1)\n",
"plt.title('Input')\n",
"plt.imshow(sample_input)\n",
"plt.axis('off')\n",
"\n",
"# 出力\n",
"plt.subplot(1, 3, 2)\n",
"plt.title('Predicated')\n",
"plt.imshow(sample_predicted)\n",
"plt.axis('off')\n",
"\n",
"# 正解\n",
"plt.subplot(1, 3, 3)\n",
"plt.title('Correct')\n",
"plt.imshow(sample_correct)\n",
"plt.axis('off')"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "7N3B6F84r_O-",
"colab_type": "code",
"colab": {}
},
"source": [
"katakana = np.expand_dims(cv2.resize(np.array(Image.open('ア.png')), dsize=(32, 32), interpolation=cv2.INTER_CUBIC), axis=0) / 255\n",
"predicted = model.predict(katakana)\n",
"plt.figure(figsize=(6,6))\n",
"plt.imshow(predicted[0])\n",
"plt.axis('off')"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "n3Ox5xFypqr_",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment