Skip to content

Instantly share code, notes, and snippets.

@maxoobot
Last active March 21, 2020 04:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maxoobot/0fb5d164c6c71bf0782c51883ef02a07 to your computer and use it in GitHub Desktop.
Save maxoobot/0fb5d164c6c71bf0782c51883ef02a07 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"uuid": "a3ebf42a-b439-4688-9789-899e28ddc200"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.0.0-alpha0\n"
]
}
],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"import os\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.model_selection import StratifiedShuffleSplit\n",
"print(tf.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"uuid": "57863da7-465a-4cd8-b296-04daa5ef770d"
},
"outputs": [],
"source": [
"# Fix the global NumPy and TensorFlow random seeds for reproducible runs.\n",
"# tf.set_random_seed is not part of the TF 2.x API; tf.random.set_seed replaces it.\n",
"np.random.seed(42)\n",
"tf.random.set_seed(42)"
]
},
{
"cell_type": "markdown",
"metadata": {
"uuid": "fa0730d8-2d4a-4517-b863-baadcc7c6a42"
},
"source": [
"### データの準備"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"uuid": "b0e88dd7-9f39-49e4-9eda-b1093fb00d20"
},
"outputs": [],
"source": [
"# Load the train/test CSV files (the header row is skipped)\n",
"train_data, test_data = [\n",
"    np.genfromtxt(csv_path, delimiter=',', skip_header=1)\n",
"    for csv_path in (\"/nas/data/fashion-mnist_train.csv\", \"/nas/data/fashion-mnist_test.csv\")\n",
"]\n",
"\n",
"# Column 0 holds the label; every remaining column is a pixel value\n",
"train_labels, train_images = train_data[:, 0], train_data[:, 1:]\n",
"test_labels, test_images = test_data[:, 0], test_data[:, 1:]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"uuid": "ddae21b1-48d9-4d1c-854e-b174722693a4"
},
"outputs": [],
"source": [
"# Reshape the flat pixel rows into 28x28 single-channel images and scale them\n",
"# into [0, 1]. NOTE: assumes the CSVs hold raw 0-255 intensities (standard for\n",
"# Fashion-MNIST); scaled inputs are the usual preprocessing and help the optimizer converge.\n",
"X = train_images.reshape((-1, 28, 28, 1)) / 255.0\n",
"X_test = test_images.reshape((-1, 28, 28, 1)) / 255.0"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"uuid": "4dc12b9a-dc9b-4d34-a0be-29587d63cc0b"
},
"outputs": [],
"source": [
"# One-hot encode the labels. The class list is derived from the training labels\n",
"# only, so the train and test matrices share the same column order and width even\n",
"# if some class is missing from the test set. A test label never seen in training\n",
"# raises KeyError instead of being silently mis-encoded.\n",
"classes = np.unique(train_labels)\n",
"classes_to_index = {c: i for i, c in enumerate(classes)}\n",
"\n",
"def one_hot(labels):\n",
"    \"\"\"Return a (len(labels), len(classes)) float array with a single 1 per row.\"\"\"\n",
"    indices = np.array([classes_to_index[c] for c in labels], dtype=int)\n",
"    return np.eye(len(classes))[indices]\n",
"\n",
"y = one_hot(train_labels)\n",
"y_test = one_hot(test_labels)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"uuid": "2e52e630-d56b-48be-90ac-2a2518949f3d"
},
"outputs": [],
"source": [
"# Hold out 1/6 of the training data as a stratified validation set.\n",
"# n_splits=1, so the splitter yields exactly one (train, valid) index pair.\n",
"strat_split = StratifiedShuffleSplit(n_splits=1, test_size=1/6, random_state=1)\n",
"train_index, valid_index = next(strat_split.split(X, y))\n",
"X_train, y_train = X[train_index], y[train_index]\n",
"X_valid, y_valid = X[valid_index], y[valid_index]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"uuid": "5030db41-a952-4acd-a1e5-97a3562b118a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(50000, 28, 28, 1) (10000, 28, 28, 1) (10000, 28, 28, 1)\n",
"(50000, 10) (10000, 10) (10000, 10)\n"
]
}
],
"source": [
"# Sanity-check the split: image tensors on the first line, label matrices on the second\n",
"for arrays in ((X_train, X_valid, X_test), (y_train, y_valid, y_test)):\n",
"    print(*(a.shape for a in arrays))"
]
},
{
"cell_type": "markdown",
"metadata": {
"cellType": "markdown",
"uuid": "6ba50475-02c1-4c86-92a4-8cbc6a73640d"
},
"source": [
"### モデル構築"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"uuid": "970132bb-553e-4e62-85a7-0d46614687c8"
},
"outputs": [],
"source": [
"# Functional-API CNN: two conv(5x5)+maxpool(2) stages, a 1024-unit dense layer\n",
"# with dropout, and a 10-way softmax output.\n",
"inputs = tf.keras.Input(shape=(28, 28, 1))\n",
"x = inputs\n",
"for filters in (32, 64):\n",
"    x = tf.keras.layers.Conv2D(filters, 5, activation='relu')(x)\n",
"    x = tf.keras.layers.MaxPooling2D(2)(x)\n",
"x = tf.keras.layers.Flatten()(x)\n",
"x = tf.keras.layers.Dense(1024, activation='relu')(x)\n",
"x = tf.keras.layers.Dropout(0.4)(x)\n",
"outputs = tf.keras.layers.Dense(10, activation='softmax')(x)\n",
"\n",
"model = tf.keras.Model(inputs=inputs, outputs=outputs, name=\"fashion_mnist_model\")\n",
"model.compile(\n",
"    optimizer='adam',\n",
"    loss='categorical_crossentropy',\n",
"    metrics=['accuracy'],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"uuid": "96bdca2d-15a7-4fcb-8821-55eaea848aa2"
},
"source": [
"### モデル学習"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"uuid": "82c6edfe-c85c-4842-84b2-29c23daf17a4"
},
"outputs": [],
"source": [
"checkpoint_dir = \"/nas/model/checkpoints/\"\n",
"# Create the checkpoint directory if needed; exist_ok avoids the check-then-create race\n",
"os.makedirs(checkpoint_dir, exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"uuid": "cb9421cf-81b2-4698-916e-fe18e0ce38b1"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 50000 samples, validate on 10000 samples\n",
"Epoch 1/20\n",
"50000/50000 [==============================] - 8s 161us/sample - loss: 0.6825 - acc: 0.8022 - val_loss: 0.3931 - val_acc: 0.8543\n",
"Epoch 2/20\n",
"50000/50000 [==============================] - 7s 135us/sample - loss: 0.4081 - acc: 0.8522 - val_loss: 0.4430 - val_acc: 0.8485\n",
"Epoch 3/20\n",
"50000/50000 [==============================] - 7s 135us/sample - loss: 0.3856 - acc: 0.8609 - val_loss: 0.3789 - val_acc: 0.8671\n",
"Epoch 4/20\n",
"50000/50000 [==============================] - 7s 135us/sample - loss: 0.3622 - acc: 0.8687 - val_loss: 0.3890 - val_acc: 0.8665\n",
"Epoch 5/20\n",
"50000/50000 [==============================] - 7s 135us/sample - loss: 0.3473 - acc: 0.8739 - val_loss: 0.3384 - val_acc: 0.8818\n",
"Epoch 6/20\n",
"50000/50000 [==============================] - 7s 135us/sample - loss: 0.3308 - acc: 0.8794 - val_loss: 0.3291 - val_acc: 0.8847\n",
"Epoch 7/20\n",
"50000/50000 [==============================] - 7s 134us/sample - loss: 0.3249 - acc: 0.8823 - val_loss: 0.3443 - val_acc: 0.8797\n",
"Epoch 8/20\n",
"50000/50000 [==============================] - 7s 134us/sample - loss: 0.3097 - acc: 0.8876 - val_loss: 0.3466 - val_acc: 0.8781\n",
"Epoch 9/20\n",
"50000/50000 [==============================] - 7s 134us/sample - loss: 0.3127 - acc: 0.8866 - val_loss: 0.3484 - val_acc: 0.8811\n",
"Epoch 10/20\n",
"50000/50000 [==============================] - 7s 134us/sample - loss: 0.3018 - acc: 0.8931 - val_loss: 0.3623 - val_acc: 0.8691\n",
"Epoch 11/20\n",
"50000/50000 [==============================] - 7s 135us/sample - loss: 0.2912 - acc: 0.8954 - val_loss: 0.3603 - val_acc: 0.8872\n",
"Epoch 12/20\n",
"50000/50000 [==============================] - 7s 135us/sample - loss: 0.2809 - acc: 0.8990 - val_loss: 0.3735 - val_acc: 0.8864\n",
"Epoch 13/20\n",
"50000/50000 [==============================] - 7s 135us/sample - loss: 0.2796 - acc: 0.8993 - val_loss: 0.3664 - val_acc: 0.8862\n",
"Epoch 14/20\n",
"50000/50000 [==============================] - 7s 134us/sample - loss: 0.2805 - acc: 0.8993 - val_loss: 0.3920 - val_acc: 0.8822\n",
"Epoch 15/20\n",
"50000/50000 [==============================] - 7s 134us/sample - loss: 0.2739 - acc: 0.9036 - val_loss: 0.3779 - val_acc: 0.8818\n",
"Epoch 16/20\n",
"50000/50000 [==============================] - 7s 134us/sample - loss: 0.2690 - acc: 0.9049 - val_loss: 0.3813 - val_acc: 0.8836\n",
"Epoch 17/20\n",
"50000/50000 [==============================] - 7s 134us/sample - loss: 0.2617 - acc: 0.9077 - val_loss: 0.3980 - val_acc: 0.8845\n",
"Epoch 18/20\n",
"50000/50000 [==============================] - 7s 134us/sample - loss: 0.2607 - acc: 0.9089 - val_loss: 0.4079 - val_acc: 0.8931\n",
"Epoch 19/20\n",
"50000/50000 [==============================] - 7s 134us/sample - loss: 0.2462 - acc: 0.9134 - val_loss: 0.4304 - val_acc: 0.8855\n",
"Epoch 20/20\n",
"50000/50000 [==============================] - 7s 135us/sample - loss: 0.2507 - acc: 0.9115 - val_loss: 0.4088 - val_acc: 0.8911\n"
]
}
],
"source": [
"# Save the weights after every epoch (the callback's default frequency), so any epoch\n",
"# can be restored later. The deprecated `period=` argument is redundant here and is omitted.\n",
"model_checkpoint = tf.keras.callbacks.ModelCheckpoint(\n",
"    checkpoint_dir + 'fashion_mnist{epoch:02d}.h5', save_weights_only=True)\n",
"history = model.fit(X_train, y_train, epochs=20, verbose=1,\n",
"                    callbacks=[model_checkpoint], validation_data=(X_valid, y_valid))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"uuid": "c8e8b8e9-1dfa-46bb-9a63-c6672d4ba31f"
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# The history key is 'acc' in older TF/Keras releases and 'accuracy' in TF >= 2.0 final\n",
"acc_key = 'accuracy' if 'accuracy' in history.history else 'acc'\n",
"# Number epochs from 1 so the x-axis matches the checkpoint file names\n",
"epochs = range(1, len(history.history[acc_key]) + 1)\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.plot(epochs, history.history[acc_key], label='train')\n",
"ax.plot(epochs, history.history['val_' + acc_key], label='validation')\n",
"ax.set(title='Model accuracy', xlabel='epoch', ylabel='accuracy')\n",
"ax.legend(loc='upper left')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"uuid": "f1756568-6caf-4c3a-a69a-16328aca2dee"
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Number epochs from 1 so the x-axis matches the checkpoint file names\n",
"epochs = range(1, len(history.history['loss']) + 1)\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.plot(epochs, history.history['loss'], label='train')\n",
"ax.plot(epochs, history.history['val_loss'], label='validation')\n",
"ax.set(title='Model loss', xlabel='epoch', ylabel='loss')\n",
"# The loss curves start high on the left, so keep the legend out of that corner\n",
"ax.legend(loc='upper right')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"uuid": "ef3cdc4f-18dd-45b7-ae0a-b4cbc2ea49d6"
},
"outputs": [],
"source": [
"# Roll the model back to the epoch with the lowest validation loss, chosen from the\n",
"# training history rather than a hard-coded epoch number, so the choice stays correct\n",
"# when training is re-run. Checkpoint files are numbered from 1 ({epoch} is 1-based).\n",
"best_epoch = int(np.argmin(history.history['val_loss'])) + 1\n",
"print('restoring weights from epoch', best_epoch)\n",
"model.load_weights(checkpoint_dir + 'fashion_mnist{:02d}.h5'.format(best_epoch))"
]
},
{
"cell_type": "markdown",
"metadata": {
"uuid": "2330b2e8-ef4e-42e4-8469-93793ba03c3f"
},
"source": [
"### モデル評価"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"uuid": "37289f6a-aebf-4066-8653-df9f1f89f58e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000/10000 [==============================] - 1s 55us/sample - loss: 0.3492 - acc: 0.8843\n"
]
},
{
"data": {
"text/plain": [
"[0.3492289930522442, 0.8843]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Loss and accuracy on the held-out test set (displayed as [loss, accuracy])\n",
"model.evaluate(x=X_test, y=y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"cellType": "markdown",
"uuid": "831865f4-039b-48ad-a82a-ea21a56ae121"
},
"source": [
"### モデル保存"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"uuid": "c9a397be-f653-4266-973d-dc77bc77f137"
},
"outputs": [],
"source": [
"# Save the complete model (not only its weights) as an HDF5 file.\n",
"# NOTE(review): the file name says 'mnist' although this is the Fashion-MNIST model.\n",
"directory = \"./\"\n",
"model_path = directory + \"mnist_model.h5\"\n",
"if not os.path.exists(directory):\n",
"    os.makedirs(directory)\n",
"model.save(model_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"uuid": "ef022bf0-c897-4772-9178-61fdb424cd89"
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Tensorflow 2.0",
"language": "python",
"name": "tf2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment