@mixuala
Last active December 12, 2018 15:53
how to run `train` and `validation` loops in the same tensorflow session using `TF-Slim` and `train_step_fn()` (Colaboratory notebook)
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "HOWTO: tf.slim train+validate.ipynb",
"version": "0.3.2",
"views": {},
"default_view": {},
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"metadata": {
"id": "p2_4yjzakP8s",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# HowTo: running `train` & `validation` loops in the same session with `tf-slim`\n",
"\n",
"`Validation` loops are run alongside `training` to monitor generalization and detect when a model enters a high-variance state, i.e. `overfitting`. \n",
"\n",
"The purpose of this notebook is to show how to run a `validation` loop in the same session as a `train` loop using `tf.slim`. We'll also show how to display `validation` summaries on `tensorboard`.\n",
"\n",
"\n",
"\n",
"## references:\n",
"\n",
"* This example borrows from the [slim walkthrough](https://github.com/tensorflow/models/blob/master/research/slim/slim_walkthrough.ipynb) notebook \n",
"* recipe for [running train/validation/test loops](https://github.com/tensorflow/tensorflow/issues/5987) in the same `tf.slim` session\n",
"* plot [training and validation losses](https://stackoverflow.com/questions/37146614/tensorboard-plot-training-and-validation-losses-on-the-same-graph) on the same tensorboard graph\n"
]
},
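The interleaving that `tf.slim` performs through `train_step_fn` can be summarized framework-free. Below is a minimal sketch of a training loop that runs a validation check on a fixed step interval; all names here are illustrative stand-ins, not part of `tf.slim`:

```python
# Minimal, framework-free sketch of interleaving validation into a
# training loop; tf.slim drives the real loop via train_step_fn below.
def run_training(steps, validation_interval, train_step, validate):
    history = []
    for step in range(steps):
        train_loss = train_step(step)        # one optimizer update
        if step % validation_interval == 0:  # periodic validation pass
            val_loss = validate()
            history.append((step, train_loss, val_loss))
    return history

# toy stand-ins for the real train/validation ops
train_step = lambda step: 1.0 / (step + 1)
validate = lambda: 0.5
history = run_training(steps=4, validation_interval=2,
                       train_step=train_step, validate=validate)
# validation fires at steps 0 and 2
```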
{
"metadata": {
"id": "qZ6P8ndskPJq",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"\n",
"\n",
"# Simple Example\n",
"\n",
"## Setup"
]
},
{
"metadata": {
"id": "gQMK5A4zpVm7",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"output_extras": [
{
"item_id": 25
}
],
"base_uri": "https://localhost:8080/",
"height": 238
},
"outputId": "f174fb88-8c2a-4491-ef5f-9f3a730649cb",
"executionInfo": {
"status": "ok",
"timestamp": 1519818617790,
"user_tz": -480,
"elapsed": 32936,
"user": {
"displayName": "michael lin",
"photoUrl": "//lh3.googleusercontent.com/-etfWG7MvQwk/AAAAAAAAAAI/AAAAAAAAADM/BxW0OLTdkjI/s50-c-k-no/photo.jpg",
"userId": "111539764795298113840"
}
}
},
"cell_type": "code",
"source": [
"HOME = \"/content\"\n",
"SLIM = \"/content/models/research/slim\"\n",
"# load repo for TF-Slim image models\n",
"!git clone https://github.com/tensorflow/models.git\n",
"!git clone https://github.com/mixuala/colab_utils.git"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Cloning into 'models'...\n",
"remote: Counting objects: 12593, done.\u001b[K\n",
"remote: Compressing objects: 100% (2/2), done.\u001b[K\n",
"remote: Total 12593 (delta 2), reused 2 (delta 2), pack-reused 12589\u001b[K\n",
"Receiving objects: 100% (12593/12593), 410.50 MiB | 44.14 MiB/s, done.\n",
"Resolving deltas: 100% (7096/7096), done.\n",
"Checking out files: 100% (1794/1794), done.\n",
"Cloning into 'colab_utils'...\n",
"remote: Counting objects: 198, done.\u001b[K\n",
"remote: Compressing objects: 100% (21/21), done.\u001b[K\n",
"remote: Total 198 (delta 9), reused 22 (delta 6), pack-reused 171\u001b[K\n",
"Receiving objects: 100% (198/198), 56.79 KiB | 11.36 MiB/s, done.\n",
"Resolving deltas: 100% (77/77), done.\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "Kvag4TvZoy2S",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"from __future__ import absolute_import\n",
"from __future__ import division\n",
"from __future__ import print_function\n",
"\n",
"import os, sys, shutil\n",
"\n",
"import matplotlib\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import math\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import time\n",
"\n",
"from models.research.slim.datasets import dataset_utils\n",
"from colab_utils import tboard\n",
"\n",
"# Main slim library\n",
"from tensorflow.contrib import slim"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "Wy-Nm6-sruSE",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Build `tf.slim` training loop"
]
},
{
"metadata": {
"id": "aQGSBeHUsQoU",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"### \n",
"### dataset pre-processing and tensorboard methods\n",
"###\n",
"def produce_batch(batch_size, noise=0.3):\n",
" xs = np.random.random(size=[batch_size, 1]) * 10\n",
" ys = np.sin(xs) + 5 + np.random.normal(size=[batch_size, 1], scale=noise)\n",
" return [xs.astype(np.float32), ys.astype(np.float32)]\n",
"\n",
"def convert_data_to_tensors(x, y):\n",
" inputs = tf.constant(x)\n",
" inputs.set_shape([None, 1])\n",
" \n",
" outputs = tf.constant(y)\n",
" outputs.set_shape([None, 1])\n",
" return inputs, outputs\n",
" \n",
" \n",
"\n",
"def reset_tensorboard(log_dir):\n",
" try:\n",
" shutil.rmtree(log_dir)\n",
" except OSError:\n",
" pass"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "z0iO_a9no91B",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"### \n",
"### net\n",
"###\n",
"def regression_model(inputs, is_training=True, scope=\"regression_model\"):\n",
" \"\"\"Creates the regression model.\n",
"\n",
" Args:\n",
" inputs: A node that yields a `Tensor` of size [batch_size, dimensions].\n",
" is_training: Whether or not we're currently training the model.\n",
" scope: An optional variable_op scope for the model.\n",
"\n",
" Returns:\n",
" predictions: 1-D `Tensor` of shape [batch_size] of responses.\n",
" end_points: A dict of end points representing the hidden layers.\n",
" \"\"\"\n",
" # Make the model, reuse weights for validation batches using reuse=tf.AUTO_REUSE\n",
" with slim.arg_scope([slim.fully_connected], reuse=tf.AUTO_REUSE):\n",
" end_points = {}\n",
" # Set the default weights_regularizer and activation_fn for each fully_connected layer.\n",
" with slim.arg_scope([slim.fully_connected],\n",
" activation_fn=tf.nn.relu,\n",
" weights_regularizer=slim.l2_regularizer(0.01)):\n",
"\n",
" # Creates a fully connected layer from the inputs with 32 hidden units.\n",
" net = slim.fully_connected(inputs, 32, scope='fc1')\n",
" end_points['fc1'] = net\n",
"\n",
" # Adds a dropout layer to prevent over-fitting.\n",
" net = slim.dropout(net, 0.8, is_training=is_training)\n",
"\n",
" # Adds another fully connected layer with 16 hidden units.\n",
" net = slim.fully_connected(net, 16, scope='fc2')\n",
" end_points['fc2'] = net\n",
"\n",
" # Creates a fully-connected layer with a single hidden unit. Note that the\n",
" # layer is made linear by setting activation_fn=None.\n",
" predictions = slim.fully_connected(net, 1, activation_fn=None, scope='prediction')\n",
" end_points['out'] = predictions\n",
"\n",
" return predictions, end_points\n"
],
"execution_count": 0,
"outputs": []
},
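`reuse=tf.AUTO_REUSE` is what lets the second call to `regression_model` (the validation pass) share weights with the first instead of creating a parallel set. The behavior is essentially a name-keyed variable registry; here is an illustrative plain-Python sketch of that semantics, not the real `tf.get_variable` implementation:

```python
_variables = {}

def get_variable(name, initializer):
    # AUTO_REUSE semantics: create on first request for a name,
    # return the same object on every later request for that name
    if name not in _variables:
        _variables[name] = initializer()
    return _variables[name]

w_train = get_variable('fc1/weights', initializer=lambda: [0.0] * 32)
w_val = get_variable('fc1/weights', initializer=lambda: [1.0] * 32)
# second call ignores its initializer and reuses the first variable
```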
{
"metadata": {
"id": "_fDTDgqEswde",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"output_extras": [
{
"item_id": 6
}
],
"base_uri": "https://localhost:8080/",
"height": 136
},
"outputId": "c010bf18-c420-452e-988f-bb735e5033c3",
"executionInfo": {
"status": "ok",
"timestamp": 1519818630934,
"user_tz": -480,
"elapsed": 6385,
"user": {
"displayName": "michael lin",
"photoUrl": "//lh3.googleusercontent.com/-etfWG7MvQwk/AAAAAAAAAAI/AAAAAAAAADM/BxW0OLTdkjI/s50-c-k-no/photo.jpg",
"userId": "111539764795298113840"
}
}
},
"cell_type": "code",
"source": [
"### \n",
"### runtime params\n",
"###\n",
"LOG_DIR = '/tmp/regression_model'\n",
"TENSORBOARD_RUN = LOG_DIR + \"/train\"\n",
"LOG_INTERVAL = 100\n",
"VALIDATION_INTERVAL = 500\n",
"STEPS = 10000\n",
"tboard.launch_tensorboard(bin_dir=\"/tmp\", log_dir=LOG_DIR)\n",
"tf.logging.set_verbosity(tf.logging.INFO)\n",
"\n",
"!ls $TENSORBOARD_RUN\n",
"print(TENSORBOARD_RUN)\n",
"\n"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"calling wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip ...\n",
"calling unzip ngrok-stable-linux-amd64.zip ...\n",
"ngrok installed. path=/tmp/ngrok\n",
"status: tensorboard=False, ngrok=False\n",
"tensorboard url= http://7ca94201.ngrok.io\n",
"ls: cannot access '/tmp/regression_model/train': No such file or directory\n",
"/tmp/regression_model/train\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "WvIQZ7ZbxIrT",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"### add `validation` loop to `train_step_fn()`\n",
"run the validation graph every `VALIDATION_INTERVAL` steps from inside the training loop"
]
},
{
"metadata": {
"id": "5uxbPyhPw-v4",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"\n",
"# \n",
"def train_step_fn(sess, train_op, global_step, train_step_kwargs):\n",
" \"\"\"\n",
" slim.learning.train_step():\n",
" train_step_kwargs = {summary_writer:, should_log:, should_stop:}\n",
"\n",
" usage: slim.learning.train( train_op, logdir, \n",
" train_step_fn=train_step_fn,)\n",
" \"\"\"\n",
" if hasattr(train_step_fn, 'step'):\n",
" train_step_fn.step += 1 # or use global_step.eval(session=sess)\n",
" else:\n",
" train_step_fn.step = global_step.eval(sess)\n",
" \n",
" # calc training losses\n",
" total_loss, should_stop = slim.learning.train_step(sess, train_op, global_step, train_step_kwargs)\n",
" \n",
" \n",
" # validate on interval\n",
" if train_step_fn.step % VALIDATION_INTERVAL == 0:\n",
" validation_loss, validation_delta = sess.run([val_loss, summary_validation_delta])\n",
" print(\">> global step {}: train={} validation={} delta={}\".format(train_step_fn.step, \n",
" total_loss, validation_loss, validation_loss-total_loss))\n",
" \n",
"\n",
" return [total_loss, should_stop]"
],
"execution_count": 0,
"outputs": []
},
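Note that `train_step_fn` keeps its own counter by attaching a `step` attribute to the function object itself, which persists across calls without a global. A standalone sketch of the pattern, where the `start` argument is a purely illustrative stand-in for `global_step.eval(sess)`:

```python
def step_fn(start=0):
    # function attributes persist across calls, like a static variable
    if hasattr(step_fn, 'step'):
        step_fn.step += 1
    else:
        step_fn.step = start  # first call: seed from the restored global step
    return step_fn.step

first = step_fn(start=10)   # seeds the counter
second = step_fn()          # increments it
```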
{
"metadata": {
"id": "HYvY5RYNtGRL",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"### build graph for training loop"
]
},
{
"metadata": {
"id": "IhwzhDFttEoh",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"# get data\n",
"x_train, y_train = produce_batch(320)\n",
"x_validation, y_validation = produce_batch(320)\n",
"\n",
"\n",
"\n",
"## graph\n",
"with tf.Graph().as_default():\n",
" \n",
" ### get test/valid datasets\n",
" inputs, targets = convert_data_to_tensors(x_train, y_train)\n",
" val_inputs, val_targets = convert_data_to_tensors(x_validation, y_validation)\n",
" \n",
" ### seems to work \n",
"# with tf.variable_scope(\"model\") as scope:\n",
" # Make the model, reuse weights for validation batches\n",
" # NOTE: do I need to use a separate scope only for the model to ensure\n",
" # I reuse ONLY the model weights between train/validate?\n",
" predictions, nodes = regression_model(inputs, is_training=True)\n",
" val_predictions, _ = regression_model(val_inputs, is_training=False)\n",
" \n",
" \n",
"\n",
" ###\n",
" ### train graph\n",
" ###\n",
" # with tf.variable_scope(\"train\"):\n",
" loss = tf.losses.mean_squared_error(labels=targets, predictions=predictions)\n",
" \n",
" ###\n",
" ### validation graph\n",
" ### evaluate from `train_step_fn()`, usually after each epoch\n",
" ###\n",
" # careful, use different loss_collection so you don't add validation losses to training losses\n",
" val_loss = tf.losses.mean_squared_error(labels=val_targets, predictions=val_predictions,\n",
" loss_collection=\"validation\" \n",
" )\n",
" \n",
" ### train_op\n",
" total_loss = tf.losses.get_total_loss() # excludes loss_collection=\"validation\"\n",
" optimizer = tf.train.AdamOptimizer(learning_rate=0.005)\n",
" train_op = slim.learning.create_train_op(total_loss, optimizer)\n",
" print(\"\\n >> total_loss=\", total_loss) \n",
" \n",
" \n",
" ### add summaries\n",
" # train summaries\n",
" summary_loss = tf.summary.scalar(\"train/loss\", total_loss)\n",
" train_writer = tf.summary.FileWriter(TENSORBOARD_RUN)\n",
"\n",
" # validation summaries\n",
" summary_validation_loss = tf.summary.scalar(\"validation/loss\", val_loss )\n",
" summary_validation_delta = tf.summary.scalar(\"validation/loss_delta\", (val_loss-loss) ) \n",
" print(\"\\n >> validation losses=\", tf.losses.get_losses(loss_collection=\"validation\"))\n",
"\n",
"\n",
" \n",
" if True: reset_tensorboard(LOG_DIR)\n",
" \n",
" \n",
" if not tf.gfile.Exists(TENSORBOARD_RUN): tf.gfile.MakeDirs(TENSORBOARD_RUN)\n",
" tboard.launch_tensorboard(bin_dir=\"/tmp\", log_dir=LOG_DIR)\n",
"\n",
" # Run the training inside a session.\n",
" final_loss = slim.learning.train(\n",
" train_op,\n",
" train_step_fn=train_step_fn,\n",
" logdir=LOG_DIR,\n",
" number_of_steps=STEPS,\n",
"# summary_op=train_summary_op,\n",
" summary_writer=train_writer,\n",
" save_summaries_secs=10,\n",
" save_interval_secs=1000,\n",
" log_every_n_steps=LOG_INTERVAL,\n",
" )\n",
" \n",
"print(\"Finished training. Last batch loss:\", final_loss)\n",
"print(\"Checkpoint saved in %s\" % LOG_DIR)"
],
"execution_count": 0,
"outputs": []
},
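The `loss_collection=\"validation\"` argument above is what keeps `tf.losses.get_total_loss()` (and therefore the `train_op`) from minimizing the validation loss. Conceptually, `tf.losses` maintains named lists of loss tensors and only sums the default collection; a plain-Python sketch of that bookkeeping (illustrative names, not the real `tf.losses` internals):

```python
# Plain-Python sketch of loss collections: only the default collection
# feeds the total loss that the optimizer minimizes.
collections = {'losses': [], 'validation': []}

def add_loss(value, loss_collection='losses'):
    collections[loss_collection].append(value)

def get_total_loss():
    # validation losses are deliberately excluded
    return sum(collections['losses'])

add_loss(0.8)                                # training MSE
add_loss(0.02)                               # e.g. an l2 regularization term
add_loss(1.1, loss_collection='validation')  # validation MSE, tracked only
total = get_total_loss()
```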
{
"metadata": {
"id": "ZSmw3ijtC9fj",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# using `TFRecord` datasets\n",
"uses the flowers dataset from the TF-Slim models repo"
]
},
{
"metadata": {
"id": "uhfASQkoDNMH",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"output_extras": [
{
"item_id": 57
},
{
"item_id": 58
}
],
"base_uri": "https://localhost:8080/",
"height": 68
},
"outputId": "c9c4256d-2248-444d-f853-7108c89e3ade",
"executionInfo": {
"status": "ok",
"timestamp": 1519819660617,
"user_tz": -480,
"elapsed": 140029,
"user": {
"displayName": "michael lin",
"photoUrl": "//lh3.googleusercontent.com/-etfWG7MvQwk/AAAAAAAAAAI/AAAAAAAAADM/BxW0OLTdkjI/s50-c-k-no/photo.jpg",
"userId": "111539764795298113840"
}
}
},
"cell_type": "code",
"source": [
"%cd $SLIM\n",
"import tensorflow as tf\n",
"from datasets import dataset_utils\n",
"\n",
"url = \"http://download.tensorflow.org/data/flowers.tar.gz\"\n",
"flowers_data_dir = '/tmp/flowers'\n",
"\n",
"if not tf.gfile.Exists(flowers_data_dir):\n",
" tf.gfile.MakeDirs(flowers_data_dir)\n",
"\n",
"dataset_utils.download_and_uncompress_tarball(url, flowers_data_dir) "
],
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"text": [
"/content/models/research/slim\n",
">> Downloading flowers.tar.gz 100.0%\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Successfully downloaded flowers.tar.gz 228649660 bytes.\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "wEaA4O3hDjGR",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"output_extras": [
{
"item_id": 1
}
],
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "c318be8f-a4a2-43b0-ef81-a873c7eb3f9c",
"executionInfo": {
"status": "ok",
"timestamp": 1519819842687,
"user_tz": -480,
"elapsed": 1094,
"user": {
"displayName": "michael lin",
"photoUrl": "//lh3.googleusercontent.com/-etfWG7MvQwk/AAAAAAAAAAI/AAAAAAAAADM/BxW0OLTdkjI/s50-c-k-no/photo.jpg",
"userId": "111539764795298113840"
}
}
},
"cell_type": "code",
"source": [
"%cd $SLIM\n",
"from preprocessing import inception_preprocessing\n",
"import tensorflow as tf\n",
"\n",
"from tensorflow.contrib import slim\n",
"\n",
"\n",
"def load_batch(dataset, batch_size=32, height=299, width=299, is_training=False):\n",
" \"\"\"Loads a single batch of data.\n",
" \n",
" Args:\n",
" dataset: The dataset to load.\n",
" batch_size: The number of images in the batch.\n",
" height: The size of each image after preprocessing.\n",
" width: The size of each image after preprocessing.\n",
" is_training: Whether or not we're currently training or evaluating.\n",
" \n",
" Returns:\n",
" images: A Tensor of size [batch_size, height, width, 3], image samples that have been preprocessed.\n",
" images_raw: A Tensor of size [batch_size, height, width, 3], image samples that can be used for visualization.\n",
" labels: A Tensor of size [batch_size], whose values range between 0 and dataset.num_classes.\n",
" \"\"\"\n",
" data_provider = slim.dataset_data_provider.DatasetDataProvider(\n",
" dataset, common_queue_capacity=32,\n",
" common_queue_min=8)\n",
" image_raw, label = data_provider.get(['image', 'label'])\n",
" \n",
" # Preprocess image for usage by Inception.\n",
" image = inception_preprocessing.preprocess_image(image_raw, height, width, is_training=is_training)\n",
" \n",
" # Preprocess the image for display purposes.\n",
" image_raw = tf.expand_dims(image_raw, 0)\n",
" image_raw = tf.image.resize_images(image_raw, [height, width])\n",
" image_raw = tf.squeeze(image_raw)\n",
"\n",
" # Batch it up.\n",
" images, images_raw, labels = tf.train.batch(\n",
" [image, image_raw, label],\n",
" batch_size=batch_size,\n",
" num_threads=1,\n",
" capacity=2 * batch_size)\n",
" \n",
" return images, images_raw, labels"
],
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"text": [
"/content/models/research/slim\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "43_I3sJwGALj",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"## runtime params\n",
"VALIDATION_INTERVAL = 1"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "0YmY9TfhFwDB",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"train_step = slim.learning.train_step\n",
"\n",
"# slim.learning.train(train_step_fn=)\n",
"def train_step_fn(sess, train_op, global_step, train_step_kwargs):\n",
" \"\"\"\n",
" slim.learning.train_step():\n",
" train_step_kwargs = {summary_writer:, should_log:, should_stop:}\n",
" \"\"\"\n",
" if hasattr(train_step_fn, 'step'):\n",
" train_step_fn.step += 1 # or use global_step.eval(session=sess)\n",
" else:\n",
" train_step_fn.step = global_step.eval(sess)\n",
" \n",
" # calc training losses\n",
" total_loss, should_stop = train_step(sess, train_op, global_step, train_step_kwargs)\n",
" \n",
" \n",
" # validate on interval\n",
" if train_step_fn.step % VALIDATION_INTERVAL == 0:\n",
" np_train_loss, np_val_loss, _ = sess.run([train_loss, val_loss, summary_validation_delta])\n",
" print(\">> global step {}: train={} validation={} delta={}\".format(train_step_fn.step, \n",
" np_train_loss, np_val_loss, np_val_loss-np_train_loss))\n",
" \n",
"\n",
" return [total_loss, should_stop]"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "ZZr9q1KbDb8u",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"def my_cnn(images, num_classes, is_training): # is_training is not used...\n",
" with slim.arg_scope([slim.conv2d, slim.fully_connected], reuse=tf.AUTO_REUSE):\n",
" with slim.arg_scope([slim.max_pool2d], kernel_size=[3, 3], stride=2):\n",
" net = slim.conv2d(images, 64, [5, 5], scope=\"conv1\")\n",
" net = slim.max_pool2d(net)\n",
" net = slim.conv2d(net, 64, [5, 5], scope=\"conv2\")\n",
" net = slim.max_pool2d(net)\n",
" net = slim.flatten(net)\n",
" net = slim.fully_connected(net, 192, scope=\"fc1\")\n",
" net = slim.fully_connected(net, num_classes, activation_fn=None, scope=\"fc2\") \n",
" return net"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "xxccWeNWLcWM",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## training loop"
]
},
{
"metadata": {
"id": "Tx0FkgqPDUMX",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"output_extras": [
{
"item_id": 13
}
],
"base_uri": "https://localhost:8080/",
"height": 323
},
"outputId": "c3846c75-4e0d-44b8-d3b9-8a6a51533def",
"executionInfo": {
"status": "ok",
"timestamp": 1519821210448,
"user_tz": -480,
"elapsed": 95100,
"user": {
"displayName": "michael lin",
"photoUrl": "//lh3.googleusercontent.com/-etfWG7MvQwk/AAAAAAAAAAI/AAAAAAAAADM/BxW0OLTdkjI/s50-c-k-no/photo.jpg",
"userId": "111539764795298113840"
}
}
},
"cell_type": "code",
"source": [
"from datasets import flowers\n",
"\n",
"# This might take a few minutes.\n",
"train_dir = LOG_DIR = '/tmp/flowers_model/'\n",
"print('Will save model to %s' % train_dir)\n",
"\n",
"tboard.launch_tensorboard(bin_dir=\"/tmp\", log_dir=LOG_DIR)\n",
"\n",
"with tf.Graph().as_default():\n",
" tf.logging.set_verbosity(tf.logging.INFO)\n",
"\n",
" dataset = flowers.get_split('train', flowers_data_dir)\n",
" images, _, labels = load_batch(dataset)\n",
" \n",
" val_dataset = flowers.get_split('validation', flowers_data_dir)\n",
" val_images, _, val_labels = load_batch(val_dataset)\n",
" \n",
" \n",
" # Create the model:\n",
" logits = my_cnn(images, num_classes=dataset.num_classes, is_training=True)\n",
" val_logits = my_cnn(val_images, num_classes=dataset.num_classes, is_training=False)\n",
" \n",
" # Specify the `train` loss function:\n",
" one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)\n",
" train_loss = tf.losses.softmax_cross_entropy(one_hot_labels, logits)\n",
" total_loss = tf.losses.get_total_loss()\n",
" \n",
" # Specify the `validation` loss function:\n",
" val_one_hot_labels = slim.one_hot_encoding(val_labels, dataset.num_classes)\n",
" val_loss = tf.losses.softmax_cross_entropy(val_one_hot_labels, val_logits, \n",
" loss_collection=\"validation\")\n",
" \n",
"\n",
" # Create some summaries to visualize the training process:\n",
" tf.summary.scalar('train/Total_Loss', total_loss)\n",
" tf.summary.scalar('validation/Validation_Loss', val_loss)\n",
" summary_validation_delta = tf.summary.scalar('validation/Validation_Delta', (val_loss - train_loss))\n",
" \n",
" # Specify the optimizer and create the train op:\n",
" optimizer = tf.train.AdamOptimizer(learning_rate=0.01)\n",
" train_op = slim.learning.create_train_op(total_loss, optimizer)\n",
"\n",
" if True: reset_tensorboard(LOG_DIR)\n",
" \n",
" # Run the training:\n",
" final_loss = slim.learning.train(\n",
" train_op,\n",
" train_step_fn=train_step_fn,\n",
" logdir=train_dir,\n",
" number_of_steps=1, # For speed, we just do 1 step\n",
" save_summaries_secs=1)\n",
" \n",
" print('Finished training. Final batch loss %f' % final_loss)"
],
"execution_count": 33,
"outputs": [
{
"output_type": "stream",
"text": [
"Will save model to /tmp/flowers_model/\n",
"ngrok installed\n",
"status: tensorboard=True, ngrok=True\n",
"tensorboard url= https://7ca94201.ngrok.io\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"INFO:tensorflow:Starting Session.\n",
"INFO:tensorflow:Saving checkpoint to path /tmp/flowers_model/model.ckpt\n",
"INFO:tensorflow:Starting Queues.\n",
"INFO:tensorflow:Recording summary at step 0.\n",
"INFO:tensorflow:Recording summary at step 0.\n",
"INFO:tensorflow:global step 1: loss = 1.6314 (51.359 sec/step)\n",
"INFO:tensorflow:Recording summary at step 1.\n",
"INFO:tensorflow:Recording summary at step 1.\n",
">> global step 0: train=5087.9873046875 validation=7657.87646484375 delta=2569.88916015625\n",
"INFO:tensorflow:Stopping Training.\n",
"INFO:tensorflow:Finished training! Saving model to disk.\n",
"Finished training. Final batch loss 1\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "QivJ4FtsE_z4",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}