Last active December 12, 2018 15:53
how to run `train` and `validation` loops in the same tensorflow session using `TF-Slim` and `train_step_fn()` (Colaboratory notebook)
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "HOWTO: tf.slim train+validate.ipynb",
"version": "0.3.2",
"views": {},
"default_view": {},
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"metadata": {
"id": "p2_4yjzakP8s",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# HowTo: running `train` & `validation` loops in the same session with `tf-slim`\n",
"\n",
"`Validation` loops are used to monitor `training` and detect when a model reaches a high-variance state, i.e. `overfitting`.\n",
"\n",
"The purpose of this notebook is to show how to run a `validation` loop in the same session as a `train` loop using `tf.slim`. We also show how to display `validation` summaries on `tensorboard`.\n",
"\n",
"\n",
"\n",
"## references:\n",
"\n",
"* This example borrows from the [slim walkthrough](https://github.com/tensorflow/models/blob/master/research/slim/slim_walkthrough.ipynb) notebook\n",
"* recipe for [running train/validation/test loops](https://github.com/tensorflow/tensorflow/issues/5987) in the same `tf.slim` session\n",
"* plot [training and validation losses](https://stackoverflow.com/questions/37146614/tensorboard-plot-training-and-validation-losses-on-the-same-graph) on the same tensorboard graph\n"
]
},
{
"metadata": {
"id": "qZ6P8ndskPJq",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"\n",
"\n",
"# Simple Example\n",
"\n",
"## Setup"
]
},
{
"metadata": {
"id": "gQMK5A4zpVm7",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"output_extras": [
{
"item_id": 25
}
],
"base_uri": "https://localhost:8080/",
"height": 238
},
"outputId": "f174fb88-8c2a-4491-ef5f-9f3a730649cb",
"executionInfo": {
"status": "ok",
"timestamp": 1519818617790,
"user_tz": -480,
"elapsed": 32936,
"user": {
"displayName": "michael lin",
"photoUrl": "//lh3.googleusercontent.com/-etfWG7MvQwk/AAAAAAAAAAI/AAAAAAAAADM/BxW0OLTdkjI/s50-c-k-no/photo.jpg",
"userId": "111539764795298113840"
}
}
},
"cell_type": "code",
"source": [
"HOME = \"/content\"\n",
"SLIM = \"/content/models/research/slim\"\n",
"# load repo for TF-Slim image models\n",
"!git clone https://github.com/tensorflow/models.git\n",
"!git clone https://github.com/mixuala/colab_utils.git"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Cloning into 'models'...\n",
"remote: Counting objects: 12593, done.\u001b[K\n",
"remote: Compressing objects: 100% (2/2), done.\u001b[K\n",
"remote: Total 12593 (delta 2), reused 2 (delta 2), pack-reused 12589\u001b[K\n",
"Receiving objects: 100% (12593/12593), 410.50 MiB | 44.14 MiB/s, done.\n",
"Resolving deltas: 100% (7096/7096), done.\n",
"Checking out files: 100% (1794/1794), done.\n",
"Cloning into 'colab_utils'...\n",
"remote: Counting objects: 198, done.\u001b[K\n",
"remote: Compressing objects: 100% (21/21), done.\u001b[K\n",
"remote: Total 198 (delta 9), reused 22 (delta 6), pack-reused 171\u001b[K\n",
"Receiving objects: 100% (198/198), 56.79 KiB | 11.36 MiB/s, done.\n",
"Resolving deltas: 100% (77/77), done.\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "Kvag4TvZoy2S",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"from __future__ import absolute_import\n",
"from __future__ import division\n",
"from __future__ import print_function\n",
"\n",
"import os, sys, shutil\n",
"\n",
"import matplotlib\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import math\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import time\n",
"\n",
"from models.research.slim.datasets import dataset_utils\n",
"from colab_utils import tboard\n",
"\n",
"# Main slim library\n",
"from tensorflow.contrib import slim"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "Wy-Nm6-sruSE",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Build `tf.slim` training loop"
]
},
{
"metadata": {
"id": "aQGSBeHUsQoU",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"### \n",
"### dataset pre-processing and tensorboard methods\n",
"###\n",
"def produce_batch(batch_size, noise=0.3):\n",
"    xs = np.random.random(size=[batch_size, 1]) * 10\n",
"    ys = np.sin(xs) + 5 + np.random.normal(size=[batch_size, 1], scale=noise)\n",
"    return [xs.astype(np.float32), ys.astype(np.float32)]\n",
"\n",
"def convert_data_to_tensors(x, y):\n",
"    inputs = tf.constant(x)\n",
"    inputs.set_shape([None, 1])\n",
"\n",
"    outputs = tf.constant(y)\n",
"    outputs.set_shape([None, 1])\n",
"    return inputs, outputs\n",
"\n",
"\n",
"def reset_tensorboard(log_dir):\n",
"    try:\n",
"        shutil.rmtree(log_dir)\n",
"    except OSError:\n",
"        # ignore if log_dir does not exist yet\n",
"        pass"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "z0iO_a9no91B",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"### \n",
"### net\n",
"###\n",
"def regression_model(inputs, is_training=True, scope=\"regression_model\"):\n",
"    \"\"\"Creates the regression model.\n",
"\n",
"    Args:\n",
"      inputs: A node that yields a `Tensor` of size [batch_size, dimensions].\n",
"      is_training: Whether or not we're currently training the model.\n",
"      scope: An optional variable_op scope for the model.\n",
"\n",
"    Returns:\n",
"      predictions: 1-D `Tensor` of shape [batch_size] of responses.\n",
"      end_points: A dict of end points representing the hidden layers.\n",
"    \"\"\"\n",
"    # Make the model, reuse weights for validation batches using reuse=tf.AUTO_REUSE\n",
"    with slim.arg_scope([slim.fully_connected], reuse=tf.AUTO_REUSE):\n",
"        end_points = {}\n",
"        # Set the default weight regularizer and activation for each fully_connected layer.\n",
"        with slim.arg_scope([slim.fully_connected],\n",
"                            activation_fn=tf.nn.relu,\n",
"                            weights_regularizer=slim.l2_regularizer(0.01)):\n",
"\n",
"            # Creates a fully connected layer from the inputs with 32 hidden units.\n",
"            net = slim.fully_connected(inputs, 32, scope='fc1')\n",
"            end_points['fc1'] = net\n",
"\n",
"            # Adds a dropout layer to prevent over-fitting.\n",
"            net = slim.dropout(net, 0.8, is_training=is_training)\n",
"\n",
"            # Adds another fully connected layer with 16 hidden units.\n",
"            net = slim.fully_connected(net, 16, scope='fc2')\n",
"            end_points['fc2'] = net\n",
"\n",
"            # Creates a fully-connected layer with a single hidden unit. Note that the\n",
"            # layer is made linear by setting activation_fn=None.\n",
"            predictions = slim.fully_connected(net, 1, activation_fn=None, scope='prediction')\n",
"            end_points['out'] = predictions\n",
"\n",
"            return predictions, end_points\n"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "_fDTDgqEswde",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"output_extras": [
{
"item_id": 6
}
],
"base_uri": "https://localhost:8080/",
"height": 136
},
"outputId": "c010bf18-c420-452e-988f-bb735e5033c3",
"executionInfo": {
"status": "ok",
"timestamp": 1519818630934,
"user_tz": -480,
"elapsed": 6385,
"user": {
"displayName": "michael lin",
"photoUrl": "//lh3.googleusercontent.com/-etfWG7MvQwk/AAAAAAAAAAI/AAAAAAAAADM/BxW0OLTdkjI/s50-c-k-no/photo.jpg",
"userId": "111539764795298113840"
}
}
},
"cell_type": "code",
"source": [
"### \n",
"### runtime params\n",
"###\n",
"LOG_DIR = '/tmp/regression_model'\n",
"TENSORBOARD_RUN = LOG_DIR + \"/train\"\n",
"LOG_INTERVAL = 100\n",
"VALIDATION_INTERVAL = 500\n",
"STEPS = 10000\n",
"tboard.launch_tensorboard(bin_dir=\"/tmp\", log_dir=LOG_DIR)\n",
"tf.logging.set_verbosity(tf.logging.INFO)\n",
"\n",
"!ls $TENSORBOARD_RUN\n",
"print(TENSORBOARD_RUN)\n",
"\n"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"calling wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip ...\n",
"calling unzip ngrok-stable-linux-amd64.zip ...\n",
"ngrok installed. path=/tmp/ngrok\n",
"status: tensorboard=False, ngrok=False\n",
"tensorboard url= http://7ca94201.ngrok.io\n",
"ls: cannot access '/tmp/regression_model/train': No such file or directory\n",
"/tmp/regression_model/train\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "WvIQZ7ZbxIrT",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"### add `validation` loop to `train_step_fn()`\n",
"validate on a given step interval"
]
},
{
"metadata": {
"id": "5uxbPyhPw-v4",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"\n",
"# \n",
"def train_step_fn(sess, train_op, global_step, train_step_kwargs):\n",
"    \"\"\"\n",
"    slim.learning.train_step():\n",
"      train_step_kwargs = {summary_writer:, should_log:, should_stop:}\n",
"\n",
"    usage: slim.learning.train( train_op, logdir, \n",
"                                train_step_fn=train_step_fn,)\n",
"    \"\"\"\n",
"    if hasattr(train_step_fn, 'step'):\n",
"        train_step_fn.step += 1  # or use global_step.eval(session=sess)\n",
"    else:\n",
"        train_step_fn.step = global_step.eval(sess)\n",
"\n",
"    # calc training losses\n",
"    total_loss, should_stop = slim.learning.train_step(sess, train_op, global_step, train_step_kwargs)\n",
"\n",
"\n",
"    # validate on interval\n",
"    if train_step_fn.step % VALIDATION_INTERVAL == 0:\n",
"        validation_loss, _ = sess.run([val_loss, summary_validation_delta])\n",
"        print(\">> global step {}: train={} validation={} delta={}\".format(train_step_fn.step, \n",
"            total_loss, validation_loss, validation_loss-total_loss))\n",
"\n",
"\n",
"    return [total_loss, should_stop]"
],
"execution_count": 0,
"outputs": []
},
{ | |
"metadata": { | |
"id": "HYvY5RYNtGRL", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"### build graph for training loop" | |
] | |
}, | |
{
"metadata": {
"id": "IhwzhDFttEoh",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"# get data\n",
"x_train, y_train = produce_batch(320)\n",
"x_validation, y_validation = produce_batch(320)\n",
"\n",
"\n",
"\n",
"## graph\n",
"with tf.Graph().as_default():\n",
"\n",
"    ### get train/validation datasets\n",
"    inputs, targets = convert_data_to_tensors(x_train, y_train)\n",
"    val_inputs, val_targets = convert_data_to_tensors(x_validation, y_validation)\n",
"\n",
"    ### seems to work \n",
"#   with tf.variable_scope(\"model\") as scope:\n",
"    # Make the model, reuse weights for validation batches\n",
"    # NOTE: do I need to use a separate scope only for the model to ensure\n",
"    # I reuse ONLY the model weights between train/validate?\n",
"    predictions, nodes = regression_model(inputs, is_training=True)\n",
"    val_predictions, _ = regression_model(val_inputs, is_training=False)\n",
"\n",
"\n",
"\n",
"    ###\n",
"    ### train graph\n",
"    ###\n",
"    # with tf.variable_scope(\"train\"):\n",
"    loss = tf.losses.mean_squared_error(labels=targets, predictions=predictions)\n",
"\n",
"    ###\n",
"    ### validation graph\n",
"    ### evaluate from `train_step_fn()`, usually after each epoch\n",
"    ###\n",
"    # careful, use a different loss_collection so you don't add validation losses to training losses\n",
"    val_loss = tf.losses.mean_squared_error(labels=val_targets, predictions=val_predictions,\n",
"                                            loss_collection=\"validation\"\n",
"                                            )\n",
"\n",
"    ### train_op\n",
"    total_loss = tf.losses.get_total_loss()  # excludes loss_collection=\"validation\"\n",
"    optimizer = tf.train.AdamOptimizer(learning_rate=0.005)\n",
"    train_op = slim.learning.create_train_op(total_loss, optimizer)\n",
"    print(\"\\n >> total_loss=\", total_loss)\n",
"\n",
"\n",
"    ### add summaries\n",
"    # train summaries\n",
"    summary_loss = tf.summary.scalar(\"train/loss\", total_loss)\n",
"    train_writer = tf.summary.FileWriter(TENSORBOARD_RUN)\n",
"\n",
"    # validation summaries\n",
"    summary_validation_loss = tf.summary.scalar(\"validation/loss\", val_loss)\n",
"    summary_validation_delta = tf.summary.scalar(\"validation/loss_delta\", (val_loss-loss))\n",
"    print(\"\\n >> validation losses=\", tf.losses.get_losses(loss_collection=\"validation\"))\n",
"\n",
"\n",
"\n",
"    if True: reset_tensorboard(LOG_DIR)\n",
"\n",
"\n",
"    if not tf.gfile.Exists(TENSORBOARD_RUN): tf.gfile.MakeDirs(TENSORBOARD_RUN)\n",
"    tboard.launch_tensorboard(bin_dir=\"/tmp\", log_dir=LOG_DIR)\n",
"\n",
"    # Run the training inside a session.\n",
"    final_loss = slim.learning.train(\n",
"        train_op,\n",
"        train_step_fn=train_step_fn,\n",
"        logdir=LOG_DIR,\n",
"        number_of_steps=STEPS,\n",
"#       summary_op=train_summary_op,\n",
"        summary_writer=train_writer,\n",
"        save_summaries_secs=10,\n",
"        save_interval_secs=1000,\n",
"        log_every_n_steps=LOG_INTERVAL,\n",
"    )\n",
"\n",
"print(\"Finished training. Last batch loss:\", final_loss)\n",
"print(\"Checkpoint saved in %s\" % LOG_DIR)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "ZSmw3ijtC9fj",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# using `TFRecord` datasets\n",
"uses the flowers dataset from the slim models repo"
]
},
{
"metadata": {
"id": "uhfASQkoDNMH",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"output_extras": [
{
"item_id": 57
},
{
"item_id": 58
}
],
"base_uri": "https://localhost:8080/",
"height": 68
},
"outputId": "c9c4256d-2248-444d-f853-7108c89e3ade",
"executionInfo": {
"status": "ok",
"timestamp": 1519819660617,
"user_tz": -480,
"elapsed": 140029,
"user": {
"displayName": "michael lin",
"photoUrl": "//lh3.googleusercontent.com/-etfWG7MvQwk/AAAAAAAAAAI/AAAAAAAAADM/BxW0OLTdkjI/s50-c-k-no/photo.jpg",
"userId": "111539764795298113840"
}
}
},
"cell_type": "code",
"source": [
"%cd $SLIM\n",
"import tensorflow as tf\n",
"from datasets import dataset_utils\n",
"\n",
"url = \"http://download.tensorflow.org/data/flowers.tar.gz\"\n",
"flowers_data_dir = '/tmp/flowers'\n",
"\n",
"if not tf.gfile.Exists(flowers_data_dir):\n",
"    tf.gfile.MakeDirs(flowers_data_dir)\n",
"\n",
"dataset_utils.download_and_uncompress_tarball(url, flowers_data_dir)"
],
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"text": [
"/content/models/research/slim\n",
">> Downloading flowers.tar.gz 100.0%\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Successfully downloaded flowers.tar.gz 228649660 bytes.\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "wEaA4O3hDjGR",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"output_extras": [
{
"item_id": 1
}
],
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "c318be8f-a4a2-43b0-ef81-a873c7eb3f9c",
"executionInfo": {
"status": "ok",
"timestamp": 1519819842687,
"user_tz": -480,
"elapsed": 1094,
"user": {
"displayName": "michael lin",
"photoUrl": "//lh3.googleusercontent.com/-etfWG7MvQwk/AAAAAAAAAAI/AAAAAAAAADM/BxW0OLTdkjI/s50-c-k-no/photo.jpg",
"userId": "111539764795298113840"
}
}
},
"cell_type": "code",
"source": [
"%cd $SLIM\n",
"from preprocessing import inception_preprocessing\n",
"import tensorflow as tf\n",
"\n",
"from tensorflow.contrib import slim\n",
"\n",
"\n",
"def load_batch(dataset, batch_size=32, height=299, width=299, is_training=False):\n",
"    \"\"\"Loads a single batch of data.\n",
"\n",
"    Args:\n",
"      dataset: The dataset to load.\n",
"      batch_size: The number of images in the batch.\n",
"      height: The size of each image after preprocessing.\n",
"      width: The size of each image after preprocessing.\n",
"      is_training: Whether we're training or evaluating.\n",
"\n",
"    Returns:\n",
"      images: A Tensor of size [batch_size, height, width, 3], image samples that have been preprocessed.\n",
"      images_raw: A Tensor of size [batch_size, height, width, 3], image samples that can be used for visualization.\n",
"      labels: A Tensor of size [batch_size], whose values range between 0 and dataset.num_classes.\n",
"    \"\"\"\n",
"    data_provider = slim.dataset_data_provider.DatasetDataProvider(\n",
"        dataset, common_queue_capacity=32,\n",
"        common_queue_min=8)\n",
"    image_raw, label = data_provider.get(['image', 'label'])\n",
"\n",
"    # Preprocess image for usage by Inception.\n",
"    image = inception_preprocessing.preprocess_image(image_raw, height, width, is_training=is_training)\n",
"\n",
"    # Preprocess the image for display purposes.\n",
"    image_raw = tf.expand_dims(image_raw, 0)\n",
"    image_raw = tf.image.resize_images(image_raw, [height, width])\n",
"    image_raw = tf.squeeze(image_raw)\n",
"\n",
"    # Batch it up.\n",
"    images, images_raw, labels = tf.train.batch(\n",
"        [image, image_raw, label],\n",
"        batch_size=batch_size,\n",
"        num_threads=1,\n",
"        capacity=2 * batch_size)\n",
"\n",
"    return images, images_raw, labels"
],
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"text": [
"/content/models/research/slim\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "43_I3sJwGALj",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"## runtime params\n",
"VALIDATION_INTERVAL = 1"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "0YmY9TfhFwDB",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"train_step = slim.learning.train_step\n",
"\n",
"# slim.learning.train(train_step_fn=)\n",
"def train_step_fn(sess, train_op, global_step, train_step_kwargs):\n",
"    \"\"\"\n",
"    slim.learning.train_step():\n",
"      train_step_kwargs = {summary_writer:, should_log:, should_stop:}\n",
"    \"\"\"\n",
"    if hasattr(train_step_fn, 'step'):\n",
"        train_step_fn.step += 1  # or use global_step.eval(session=sess)\n",
"    else:\n",
"        train_step_fn.step = global_step.eval(sess)\n",
"\n",
"    # calc training losses\n",
"    total_loss, should_stop = train_step(sess, train_op, global_step, train_step_kwargs)\n",
"\n",
"\n",
"    # validate on interval\n",
"    if train_step_fn.step % VALIDATION_INTERVAL == 0:\n",
"        np_train_loss, np_val_loss, _ = sess.run([train_loss, val_loss, summary_validation_delta])\n",
"        print(\">> global step {}: train={} validation={} delta={}\".format(train_step_fn.step, \n",
"            np_train_loss, np_val_loss, np_val_loss-np_train_loss))\n",
"\n",
"\n",
"    return [total_loss, should_stop]"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "ZZr9q1KbDb8u",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"def my_cnn(images, num_classes, is_training):  # is_training is not used...\n",
"    with slim.arg_scope([slim.conv2d, slim.fully_connected], reuse=tf.AUTO_REUSE):\n",
"        with slim.arg_scope([slim.max_pool2d], kernel_size=[3, 3], stride=2):\n",
"            net = slim.conv2d(images, 64, [5, 5], scope=\"conv1\")\n",
"            net = slim.max_pool2d(net)\n",
"            net = slim.conv2d(net, 64, [5, 5], scope=\"conv2\")\n",
"            net = slim.max_pool2d(net)\n",
"            net = slim.flatten(net)\n",
"            net = slim.fully_connected(net, 192, scope=\"fc1\")\n",
"            net = slim.fully_connected(net, num_classes, activation_fn=None, scope=\"fc2\")\n",
"            return net"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "xxccWeNWLcWM",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## training loop"
]
},
{
"metadata": {
"id": "Tx0FkgqPDUMX",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"output_extras": [
{
"item_id": 13
}
],
"base_uri": "https://localhost:8080/",
"height": 323
},
"outputId": "c3846c75-4e0d-44b8-d3b9-8a6a51533def",
"executionInfo": {
"status": "ok",
"timestamp": 1519821210448,
"user_tz": -480,
"elapsed": 95100,
"user": {
"displayName": "michael lin",
"photoUrl": "//lh3.googleusercontent.com/-etfWG7MvQwk/AAAAAAAAAAI/AAAAAAAAADM/BxW0OLTdkjI/s50-c-k-no/photo.jpg",
"userId": "111539764795298113840"
}
}
},
"cell_type": "code",
"source": [
"from datasets import flowers\n",
"\n",
"# This might take a few minutes.\n",
"train_dir = LOG_DIR = '/tmp/flowers_model/'\n",
"print('Will save model to %s' % train_dir)\n",
"\n",
"tboard.launch_tensorboard(bin_dir=\"/tmp\", log_dir=LOG_DIR)\n",
"\n",
"with tf.Graph().as_default():\n",
"    tf.logging.set_verbosity(tf.logging.INFO)\n",
"\n",
"    dataset = flowers.get_split('train', flowers_data_dir)\n",
"    images, _, labels = load_batch(dataset)\n",
"\n",
"    val_dataset = flowers.get_split('validation', flowers_data_dir)\n",
"    val_images, _, val_labels = load_batch(val_dataset)\n",
"\n",
"\n",
"    # Create the model:\n",
"    logits = my_cnn(images, num_classes=dataset.num_classes, is_training=True)\n",
"    val_logits = my_cnn(val_images, num_classes=dataset.num_classes, is_training=False)\n",
"\n",
"    # Specify the `train` loss function:\n",
"    one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)\n",
"    train_loss = tf.losses.softmax_cross_entropy(one_hot_labels, logits)\n",
"    total_loss = tf.losses.get_total_loss()\n",
"\n",
"    # Specify the `validation` loss function:\n",
"    val_one_hot_labels = slim.one_hot_encoding(val_labels, dataset.num_classes)\n",
"    val_loss = tf.losses.softmax_cross_entropy(val_one_hot_labels, val_logits, \n",
"                                               loss_collection=\"validation\")\n",
"\n",
"\n",
"    # Create some summaries to visualize the training process:\n",
"    tf.summary.scalar('train/Total_Loss', total_loss)\n",
"    tf.summary.scalar('validation/Validation_Loss', val_loss)\n",
"    summary_validation_delta = tf.summary.scalar('validation/Validation_Delta', (val_loss - train_loss))\n",
"\n",
"    # Specify the optimizer and create the train op:\n",
"    optimizer = tf.train.AdamOptimizer(learning_rate=0.01)\n",
"    train_op = slim.learning.create_train_op(total_loss, optimizer)\n",
"\n",
"    if True: reset_tensorboard(LOG_DIR)\n",
"\n",
"    # Run the training:\n",
"    final_loss = slim.learning.train(\n",
"        train_op,\n",
"        train_step_fn=train_step_fn,\n",
"        logdir=train_dir,\n",
"        number_of_steps=1,  # For speed, we run just 1 step\n",
"        save_summaries_secs=1)\n",
"\n",
"    print('Finished training. Final batch loss %f' % final_loss)"
],
"execution_count": 33,
"outputs": [
{
"output_type": "stream",
"text": [
"Will save model to /tmp/flowers_model/\n",
"ngrok installed\n",
"status: tensorboard=True, ngrok=True\n",
"tensorboard url= https://7ca94201.ngrok.io\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"INFO:tensorflow:Starting Session.\n",
"INFO:tensorflow:Saving checkpoint to path /tmp/flowers_model/model.ckpt\n",
"INFO:tensorflow:Starting Queues.\n",
"INFO:tensorflow:Recording summary at step 0.\n",
"INFO:tensorflow:Recording summary at step 0.\n",
"INFO:tensorflow:global step 1: loss = 1.6314 (51.359 sec/step)\n",
"INFO:tensorflow:Recording summary at step 1.\n",
"INFO:tensorflow:Recording summary at step 1.\n",
">> global step 0: train=5087.9873046875 validation=7657.87646484375 delta=2569.88916015625\n",
"INFO:tensorflow:Stopping Training.\n",
"INFO:tensorflow:Finished training! Saving model to disk.\n",
"Finished training. Final batch loss 1\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "QivJ4FtsE_z4",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}