Skip to content

Instantly share code, notes, and snippets.

@jenyckee
Created December 19, 2018 18:00
Show Gist options
  • Save jenyckee/fc9fa36ae8a67219ca0adf1a8fc281e4 to your computer and use it in GitHub Desktop.
Save jenyckee/fc9fa36ae8a67219ca0adf1a8fc281e4 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"import argparse\n",
"import os\n",
"import tensorflow as tf\n",
"from tensorflow.contrib.training.python.training import hparam\n",
"import numpy as np\n",
"import trainer.data as data\n",
"import trainer.model as model"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
"def solution(features, labels, mode):\n",
"    \"\"\"Build the CNN classifier and return an EstimatorSpec for `mode`.\n",
"\n",
"    Args:\n",
"        features: dict whose \"x\" entry holds the image batch; it is reshaped\n",
"            to [batch, 64, 64, 3] (64x64 RGB images).\n",
"        labels: class indices for the batch, as required by\n",
"            tf.losses.sparse_softmax_cross_entropy (4 classes). Not used in\n",
"            PREDICT mode.\n",
"        mode: a tf.estimator.ModeKeys value (TRAIN, EVAL or PREDICT).\n",
"\n",
"    Returns:\n",
"        A tf.estimator.EstimatorSpec carrying predictions (PREDICT), a\n",
"        gradient-descent train_op (TRAIN) or an accuracy metric (EVAL).\n",
"    \"\"\"\n",
"    # Input layer: a batch of 64x64 RGB images.\n",
"    input_layer = tf.reshape(features[\"x\"], [-1, 64, 64, 3])\n",
"\n",
"    # Convolutional layer #1: 32 5x5 filters; \"same\" padding keeps 64x64.\n",
"    conv1 = tf.layers.conv2d(\n",
"        inputs=input_layer,\n",
"        filters=32,\n",
"        kernel_size=[5, 5],\n",
"        padding=\"same\",\n",
"        activation=tf.nn.relu)\n",
"    # Pooling layer #1: 2x2 max pool, stride 2 -> [batch, 32, 32, 32].\n",
"    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)\n",
"\n",
"    # Convolutional layer #2: 64 5x5 filters -> [batch, 32, 32, 64].\n",
"    conv2 = tf.layers.conv2d(\n",
"        inputs=pool1,\n",
"        filters=64,\n",
"        kernel_size=[5, 5],\n",
"        padding=\"same\",\n",
"        activation=tf.nn.relu)\n",
"    # Pooling layer #2 -> [batch, 16, 16, 64].\n",
"    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)\n",
"\n",
"    # Flatten for the dense layer. Two stride-2 pools shrink 64x64 to 16x16,\n",
"    # so each example holds 16 * 16 * 64 values. With 8 * 8 * 64 the -1\n",
"    # dimension silently became 4x the batch size, so the logits no longer\n",
"    # lined up with `labels`.\n",
"    pool2_flat = tf.reshape(pool2, [-1, 16 * 16 * 64])\n",
"\n",
"    # Dense layer; dropout is only applied in TRAIN mode.\n",
"    dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)\n",
"    dropout = tf.layers.dropout(\n",
"        inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)\n",
"\n",
"    # Logits layer: one unit per class.\n",
"    logits = tf.layers.dense(inputs=dropout, units=4)\n",
"\n",
"    predictions = {\n",
"        # Predicted class index per example (used by PREDICT and EVAL).\n",
"        \"classes\": tf.argmax(input=logits, axis=1),\n",
"        # Per-class probabilities, named \"softmax_tensor\" so a logging hook\n",
"        # can fetch it by name.\n",
"        \"probabilities\": tf.nn.softmax(logits, name=\"softmax_tensor\")\n",
"    }\n",
"\n",
"    if mode == tf.estimator.ModeKeys.PREDICT:\n",
"        return tf.estimator.EstimatorSpec(mode, predictions=predictions)\n",
"\n",
"    # The loss is needed for both TRAIN and EVAL.\n",
"    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)\n",
"\n",
"    if mode == tf.estimator.ModeKeys.TRAIN:\n",
"        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)\n",
"        train_op = optimizer.minimize(\n",
"            loss=loss,\n",
"            global_step=tf.train.get_global_step())\n",
"        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)\n",
"\n",
"    if mode == tf.estimator.ModeKeys.EVAL:\n",
"        # Accuracy of the predicted classes against the true labels.\n",
"        eval_metric_ops = {\n",
"            \"accuracy\": tf.metrics.accuracy(\n",
"                labels=labels, predictions=predictions[\"classes\"])}\n",
"        return tf.estimator.EstimatorSpec(\n",
"            mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(2000, 64, 64, 3)\n",
"(2000,)\n",
"INFO:tensorflow:Using default config.\n",
"WARNING:tensorflow:Using temporary folder as model directory: /var/folders/tc/08g7j299597fg1rmqgyd0z1h0000gn/T/tmpao5ylhln\n",
"INFO:tensorflow:Using config: {'_model_dir': '/var/folders/tc/08g7j299597fg1rmqgyd0z1h0000gn/T/tmpao5ylhln', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x153c21be0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n"
]
}
],
"source": [
"# Load the training images and labels via trainer.data.\n",
"(train_data, train_labels) = data.create_data_with_labels(\"data/train/\")\n",
"\n",
"print(train_data.shape)\n",
"print(train_labels.shape)\n",
"\n",
"# Feed the numpy arrays in shuffled batches of 100, repeating indefinitely\n",
"# (num_epochs=None); the step count given to estimator.train() bounds the run.\n",
"train_input_fn = tf.estimator.inputs.numpy_input_fn(\n",
"    x={\"x\": train_data},\n",
"    y=train_labels,\n",
"    batch_size=100,\n",
"    num_epochs=None,\n",
"    shuffle=True)\n",
"\n",
"estimator = tf.estimator.Estimator(model_fn=solution)\n",
"\n",
"# get_training_steps() is not defined anywhere in this notebook, so running it\n",
"# top to bottom raised NameError here. Set the training budget explicitly.\n",
"# NOTE(review): 1000 is a placeholder; use the step count the trainer expects.\n",
"training_steps = 1000\n",
"eval_steps = 5\n",
"steps_per_eval = int(training_steps / eval_steps)"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [],
"source": [
"# estimator.train(train_input_fn, steps=steps_per_eval)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment