Last active
February 6, 2019 20:54
-
-
Save shang-vikas/36176e3bed6f3234fd1c27465d8bec22 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## <font color=brown>Simple Graph Execution</font>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"## import necessary stuff\n", | |
"import tensorflow as tf\n", | |
"import numpy as np\n", | |
"import os,sys\n", | |
"from tensorflow.keras.datasets import mnist\n", | |
"\n", | |
"import time" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"## load data of mnist from keras. alternatively you can load data from tf.contrib.learn.datasets.\n", | |
"(x_train,y_train),(x_test,y_test) = mnist.load_data()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"x_train = x_train.reshape(-1,784)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Building a simple DNN model and feeding it NumPy/pandas data" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Simple MNIST Model of Dense layers" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Using default config.\n", | |
"WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpap27xr75\n", | |
"INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpap27xr75', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fa4160c9080>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n" | |
] | |
} | |
], | |
"source": [ | |
"## defining the type of features columns to be used on model.\n", | |
"feature_column = [tf.feature_column.numeric_column(key='image',shape=(784,))]\n", | |
"\n", | |
"##defining the model\n", | |
"model = tf.estimator.DNNClassifier([100,100],n_classes=10,feature_columns=feature_column)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Defining the dataset flow" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"'''parse function to be used. This function is needed to do \n", | |
"the preprocessing of data like reshaping ,\n", | |
"converting to tensors from numpy arrays ,one-hot encoding ,etc.'''\n", | |
"def _parse_(x,y):\n", | |
" x = tf.reshape(x,(784,)) #reshape to 784 as expected.\n", | |
" x = tf.cast(x,tf.float32) #cast to float32 as the weights are float32.\n", | |
" # y = tf.one_hot(y,10) #should be one hot encoded but tensorflow is throwing error .idky??\n", | |
" y = tf.cast(y,tf.int32) #cast to tensor of int32\n", | |
" return (dict({'image':x}),y) #return tuple of dict of feature name with value and label." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"##define the function that feeds the data to the model .\n", | |
"def train_input_fn(x_train,y_train,batch_size=64):\n", | |
" ##Here we are using dataset API.\n", | |
" '''\n", | |
" take the data from tensor_slices i.e. an array of datapoints in simple words.\n", | |
" '''\n", | |
" dataset = tf.data.Dataset.from_tensor_slices((x_train,y_train)) \n", | |
" \n", | |
" \n", | |
" dataset = dataset.map(lambda x,y:_parse_(x,y)).shuffle(buffer_size=128) \\\n", | |
" .batch(batch_size).make_one_shot_iterator()\n", | |
" \n", | |
" return dataset.get_next()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Training the model by feeding data through the above defined flow" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Calling model_fn.\n", | |
"INFO:tensorflow:Done calling model_fn.\n", | |
"INFO:tensorflow:Create CheckpointSaverHook.\n", | |
"INFO:tensorflow:Graph was finalized.\n", | |
"INFO:tensorflow:Running local_init_op.\n", | |
"INFO:tensorflow:Done running local_init_op.\n", | |
"INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpap27xr75/model.ckpt.\n", | |
"INFO:tensorflow:loss = 8941.259, step = 0\n", | |
"INFO:tensorflow:global_step/sec: 240.438\n", | |
"INFO:tensorflow:loss = 101.23705, step = 100 (0.417 sec)\n", | |
"INFO:tensorflow:Saving checkpoints for 150 into /tmp/tmpap27xr75/model.ckpt.\n", | |
"INFO:tensorflow:Loss for final step: 125.124916.\n", | |
"---------18.832788705825806\n" | |
] | |
} | |
], | |
"source": [ | |
"t1 = time.time()\n", | |
"model.train(input_fn=lambda:train_input_fn(x_train,y_train,64),steps=150)\n", | |
"t2= time.time()\n", | |
"print('---------{}'.format(t2 - t1))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"t1 = time.time()\n", | |
"model.train(input_fn=lambda:train_input_fn(x_train,y_train,64),steps=251)\n", | |
"t2= time.time()\n", | |
"print('---------{}'.format(t2 - t1))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Train using the numpy input function**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Calling model_fn.\n", | |
"INFO:tensorflow:Done calling model_fn.\n", | |
"INFO:tensorflow:Create CheckpointSaverHook.\n", | |
"INFO:tensorflow:Graph was finalized.\n", | |
"INFO:tensorflow:Restoring parameters from /tmp/tmpbqg7sg1j/model.ckpt-102\n", | |
"INFO:tensorflow:Running local_init_op.\n", | |
"INFO:tensorflow:Done running local_init_op.\n", | |
"INFO:tensorflow:Saving checkpoints for 102 into /tmp/tmpbqg7sg1j/model.ckpt.\n", | |
"INFO:tensorflow:loss = 174.48169, step = 102\n", | |
"INFO:tensorflow:global_step/sec: 642.66\n", | |
"INFO:tensorflow:loss = 149.84337, step = 202 (0.157 sec)\n", | |
"INFO:tensorflow:Saving checkpoints for 203 into /tmp/tmpbqg7sg1j/model.ckpt.\n", | |
"INFO:tensorflow:Loss for final step: 149.84337.\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<tensorflow.python.estimator.canned.dnn.DNNClassifier at 0x7f5f2706c208>" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.train(input_fn=tf.estimator.inputs.numpy_input_fn(dict({'image':x_train}),\n", | |
" np.array(y_train,np.int32),shuffle=True),steps=101)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment