Skip to content

Instantly share code, notes, and snippets.

@krishpop
Last active November 21, 2017 23:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save krishpop/dda70498d31b46efdb78f495a6754ab8 to your computer and use it in GitHub Desktop.
Save krishpop/dda70498d31b46efdb78f495a6754ab8 to your computer and use it in GitHub Desktop.
TF Tutorial
class DataSet(object):
    """In-memory data set that hands out (optionally shuffled) mini-batches.

    Every keyword in ``data_dict`` is stored as an instance attribute of the
    same name. ``_data`` is required; ``_labels`` is required when
    ``labeled`` is true and must have the same first-dimension length as
    ``_data``. Optional extras such as ``_test_data`` / ``_test_labels`` are
    stored untouched and exposed through the matching read-only properties.

    The class derives from ``object`` so the ``epochs_trained`` property
    setter also works on Python 2, where a bare ``class DataSet:`` would be
    an old-style class.
    """

    def __init__(self, shuffle=True, labeled=True, **data_dict):
        """Store the arrays and reset the batching counters.

        Args:
            shuffle: reshuffle the samples now and at every epoch boundary.
            labeled: whether ``_labels`` accompanies ``_data``.
            **data_dict: arrays to attach; must contain ``_data`` (and
                ``_labels`` when ``labeled``).

        Raises:
            AssertionError: if a required array is missing or the sample
                counts of ``_data`` and ``_labels`` differ.
        """
        # Kept as asserts so existing callers still see AssertionError.
        assert '_data' in data_dict, "data_dict must contain '_data'"
        if labeled:
            assert '_labels' in data_dict, \
                "data_dict must contain '_labels' when labeled=True"
            assert (data_dict['_data'].shape[0] ==
                    data_dict['_labels'].shape[0]), \
                "'_data' and '_labels' must have the same number of samples"
        self._labeled = labeled
        self._shuffle = shuffle
        self.__dict__.update(data_dict)
        # Counters are (re)initialised after the update so stale values in
        # data_dict can never leak into a fresh instance.
        self._num_samples = self._data.shape[0]
        self._index_in_epoch = 0
        self._epochs_trained = 0
        self._batch_number = 0
        if self._shuffle:
            self._shuffle_data()

    def __len__(self):
        """Return the number of training samples plus test samples, if any."""
        total = len(self._data)
        # '_test_data' is optional; without this guard len() raised
        # AttributeError for data sets built from '_data' alone.
        held_out = getattr(self, '_test_data', None)
        if held_out is not None:
            total += len(held_out)
        return total

    @property
    def epochs_trained(self):
        """Number of times next_batch has wrapped past the end of the data."""
        return self._epochs_trained

    @epochs_trained.setter
    def epochs_trained(self, new_epochs_trained):
        self._epochs_trained = new_epochs_trained

    @property
    def batch_number(self):
        """Count of batches served since the last epoch wrap-around."""
        return self._batch_number

    @property
    def index_in_epoch(self):
        """Offset into the data at which the next batch starts."""
        return self._index_in_epoch

    @property
    def num_samples(self):
        """Length of the first dimension of the training data."""
        return self._num_samples

    @property
    def data(self):
        """Training samples in their current (possibly shuffled) order."""
        return self._data

    @property
    def labels(self):
        """Training labels, kept in the same order as ``data``."""
        return self._labels

    @property
    def labeled(self):
        """Whether this data set carries labels."""
        return self._labeled

    @property
    def test_data(self):
        """Held-out samples; only available if ``_test_data`` was supplied."""
        return self._test_data

    @property
    def test_labels(self):
        """Held-out labels; only available if ``_test_labels`` was supplied."""
        return self._test_labels

    @classmethod
    def load(cls, filename):
        """Rebuild a data set from an ``.npz`` archive written by ``save``.

        Args:
            filename: path or file-like object accepted by ``np.load``.

        Returns:
            A new instance with fresh batching counters. The stored
            ``_shuffle`` flag is honoured (defaulting to True when the
            archive has none).

        Raises:
            KeyError: if the archive has no ``_labeled`` entry.
        """
        # Read everything eagerly so the archive can be closed right away
        # instead of leaking the underlying file handle.
        with np.load(filename) as archive:
            arrays = dict((key, archive[key]) for key in archive.files)
        # Flags come back as 0-d arrays; unwrap them to plain bools.
        labeled = bool(arrays.pop('_labeled'))
        shuffle = bool(arrays.pop('_shuffle', True))
        # Bookkeeping counters are recomputed by __init__, so drop the saved
        # copies rather than passing 0-d arrays through.
        for counter in ('_num_samples', '_index_in_epoch',
                        '_epochs_trained', '_batch_number'):
            arrays.pop(counter, None)
        return cls(shuffle=shuffle, labeled=labeled, **arrays)

    def save(self, filename):
        """Write every instance attribute to a compressed ``.npz`` archive."""
        np.savez_compressed(filename, **self.__dict__)

    def _shuffle_data(self):
        """Apply one random permutation to the data (and labels, if any)."""
        order = np.random.permutation(self._num_samples)
        self._data = self._data[order]
        if self._labeled:
            # Same permutation keeps every sample aligned with its label.
            self._labels = self._labels[order]

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` samples, wrapping across epochs.

        When fewer than ``batch_size`` samples remain in the current epoch,
        the leftover tail is combined with the first samples of the next
        (optionally reshuffled) epoch so that every batch has full size.

        Args:
            batch_size: number of samples to return; must not exceed
                ``num_samples``.

        Returns:
            ``(data_batch, labels_batch)`` when labeled, otherwise just
            ``data_batch``.

        Raises:
            AssertionError: if ``batch_size`` exceeds ``num_samples``.
        """
        assert batch_size <= self._num_samples, \
            "batch_size must not exceed the number of samples"
        start = self._index_in_epoch
        end = start + batch_size
        labels_batch = None
        if end > self._num_samples:
            # Epoch boundary: take what is left, then top up from the start
            # of the next epoch.
            self._epochs_trained += 1
            self._batch_number = 0
            data_tail = self._data[start:]
            labels_tail = self._labels[start:] if self._labeled else None
            remaining = end - self._num_samples
            if self._shuffle:
                # The tails above still reference the pre-shuffle arrays.
                self._shuffle_data()
            data_batch = np.concatenate(
                [data_tail, self._data[:remaining]], axis=0)
            if self._labeled:
                labels_batch = np.concatenate(
                    [labels_tail, self._labels[:remaining]], axis=0)
            self._index_in_epoch = remaining
        else:
            data_batch = self._data[start:end]
            if self._labeled:
                labels_batch = self._labels[start:end]
            self._index_in_epoch = end
        self._batch_number += 1
        if self._labeled:
            return data_batch, labels_batch
        return data_batch
"""
## example
import pandas as pd
from sklearn.model_selection import train_test_split
filename = 'filename.csv'
csv_data = pd.read_csv(filename)
train, test = train_test_split(csv_data, train_size=.9)
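# pass NumPy arrays: DataSet shuffles rows by integer-array indexing,
# which a DataFrame would interpret as column selection
data_dict = {'_data': train.values, '_test_data': test.values}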
data = DataSet(labeled=False, **data_dict)
num_steps = 1000
for _ in range(num_steps):
batch = data.next_batch(100)
# do something with batch
"""
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
"Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
"Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
"Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n"
]
}
],
"source": [
"%matplotlib inline\n",
"\n",
"import tensorflow as tf\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from tensorflow.examples.tutorials.mnist import input_data\n",
"\n",
"mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.argmax(mnist.train.labels[1])"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(784,)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mnist.train.images[1].shape"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x11cb3fa50>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAADhFJREFUeJzt3V2MVPUZx/HfU9Eb9EJZuhLFxRqD\nUS/QrKYXSDRWFGMC3BhfYmiqrDGaFO1F8SXWBEXTVCvcoGskYuNbA2wkBquWNECThvBmfdkFtQYF\ngiyIiRovrO7Tizk0q+75n2HmzJxZnu8n2ezMeebMPB73x5kz/znnb+4uAPH8rOoGAFSD8ANBEX4g\nKMIPBEX4gaAIPxAU4QeCIvxAUIQfCGpCO1/MzPg6IdBi7m71PK6pPb+ZXWNmu83sIzNb3MxzAWgv\na/S7/WZ2gqQPJF0laZ+krZJudPfBxDrs+YEWa8ee/1JJH7n7x+7+raSXJc1t4vkAtFEz4T9D0t5R\n9/dly37AzPrMbJuZbWvitQCUrOUf+Ll7v6R+ibf9QCdpZs+/X9LUUffPzJYBGAeaCf9WSeea2dlm\ndpKkGyStK6ctAK3W8Nt+d//OzO6S9IakEyStdPf3S+sMQEs1PNTX0ItxzA+0XFu+5ANg/CL8QFCE\nHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQ\nhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gqIan6JYkM9sj6StJ30v6zt17\ny2gK7dPT05Os33bbbcn6/fffn6ynZoE2S08mOzQ0lKw/8MADyfrAwECyHl1T4c9c4e6HS3geAG3E\n234gqGbD75LeNLPtZtZXRkMA2qPZt/0z3X2/mf1c0ltmtsvdN41+QPaPAv8wAB2mqT2/u+/Pfg9L\nGpB06RiP6Xf3Xj4MBDpLw+E3s4lmdsrR25JmS3qvrMYAtFYzb/u7JQ1kwzUTJL3o7n8rpSsALWep\ncdjSX8ysfS8WyOTJk3Nr9957b3Ldm2++OVmfNGlSsl40Vt/MOH/R3+bevXuT9UsuuSS3dvjw8Ts6\n7e7pDZthqA8IivADQRF+ICjCDwRF+IGgCD8QFEN940DRabNLlizJrRX9/231cNuhQ4eS9ZSurq5k\nfdq0acn64OBgbu2CCy5opKVxgaE+AEmEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4/zjwNatW5P1iy++\nOLfW7Dh/aqxckq644opkvZlTZ2fOnJmsb9y4MVlP/bdPmFDGhas7E+P8AJIIPxAU4QeCIvxAUIQf\nCIrwA0ERfiAoxvk7wHnnnZesF43zf/7557m1ovPpi8bh77777mR90aJFyfrSpUtza59++mly3SJF\nf7sjIyO5tTvuuCO5bn9/f0M9dQLG+QEkEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIXj/Ga2UtJ1kobd\n/cJs2WmSXpE0TdIeSde7+xeFL8Y4f0OKvgeQGqtvdirqvr6+ZH3FihXJemqa7B07diTXnT9/frK+\nevXqZD31t3366acn1x3PU3iXOc7/nKRrfrRssaQN7n6upA3ZfQDjSGH43X2TpCM/WjxX0qrs9ipJ\n80ruC0CLNXrM3+3uB7Lbn0nqLqkfAG3S9IXM3N1Tx/Jm1icpfeAIoO0a3fMfNLMpkpT9Hs57oLv3\nu3uvu/c2+FoAWqDR8K+TtCC7vUDSq+W0A6BdCsNvZi9J+pek6Wa2z8xulfSYpKvM7ENJv8ruAxhH\nCo/53f3GnNKVJfeCHLt27arstYuuB7B79+5kPXWtgaJrBSxenB5BLppzoJXffzge8A0/ICjCDwRF\n+IGgCD8QFOEHgiL8QFDH7zzFgcyaNSu3VnQ6
cNFQ3tDQULI+ffr0ZH3Lli25tcmTJyfXLTrdvKj3\nOXPmJOvRsecHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAY5z8O3HTTTbm1hQsXJtctOi22jku7J+up\nsfxmTsmVpOXLlyfrRZcGj449PxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ExTj/ca5onL7K9Tdv3pxc\n95577knWGcdvDnt+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiqcJzfzFZKuk7SsLtfmC17SNJCSUcv\nnH6fu69vVZNIe/HFF3NrPT09yXW7urqS9aLr/k+cODFZT3nwwQeTdcbxW6uePf9zkq4ZY/mf3X1G\n9kPwgXGmMPzuvknSkTb0AqCNmjnmv8vM3jGzlWZ2amkdAWiLRsO/QtI5kmZIOiDp8bwHmlmfmW0z\ns20NvhaAFmgo/O5+0N2/d/cRSc9IujTx2H5373X33kabBFC+hsJvZlNG3Z0v6b1y2gHQLvUM9b0k\n6XJJXWa2T9IfJF1uZjMkuaQ9km5vYY8AWsCaPV/7mF7MrH0vhlIUjfM//PDDyfq8efNyazt37kyu\nO2fOnGS96Lr+Ubl7ekKEDN/wA4Ii/EBQhB8IivADQRF+ICjCDwTFUF+dUlNNHzp0KLcW3euvv55b\nu/rqq5PrFl26+8knn2yop+MdQ30Akgg/EBThB4Ii/EBQhB8IivADQRF+ICim6M7MmjUrWX/88dwr\nlWnXrl3JdW+55ZaGejoePPLII7m12bNnJ9edPn162e1gFPb8QFCEHwiK8ANBEX4gKMIPBEX4gaAI\nPxBUmHH+1Pn4kvTUU08l68PDw7m1yOP4RVN0P/3007k1s7pOO0eLsOcHgiL8QFCEHwiK8ANBEX4g\nKMIPBEX4gaAKx/nNbKqk5yV1S3JJ/e6+zMxOk/SKpGmS9ki63t2/aF2rzZk/f36yXnTu+MaNG8ts\nZ9womqJ7zZo1yXpquxbNGVF0nQQ0p549/3eSfufu50v6paQ7zex8SYslbXD3cyVtyO4DGCcKw+/u\nB9x9R3b7K0lDks6QNFfSquxhqyTNa1WTAMp3TMf8ZjZN0kWStkjqdvcDWekz1Q4LAIwTdX+338xO\nlrRG0iJ3/3L097Ld3fPm4TOzPkl9zTYKoFx17fnN7ETVgv+Cu6/NFh80sylZfYqkMc98cfd+d+91\n994yGgZQjsLwW20X/6ykIXd/YlRpnaQF2e0Fkl4tvz0ArVI4RbeZzZS0WdK7kkayxfepdtz/V0ln\nSfpEtaG+IwXPVdkU3UVDVkNDQ8n64OBgbu3RRx9t6rm3b9+erBfp6enJrV122WXJdYuGQOfNS3+O\nW3Raburva9myZcl1i6boxtjqnaK78Jjf3f8pKe/JrjyWpgB0Dr7hBwRF+IGgCD8QFOEHgiL8QFCE\nHwiqcJy/1BercJy/yOrVq5P11Hh3M2PdkrRz585kvchZZ52VW5s0aVJy3WZ7L1o/NUX38uXLk+se\nPnw4WcfY6h3nZ88PBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0Exzp8pmsJ7/fr1ubXe3vRFikZGRpL1\nVo61F637zTffJOtFl89eunRpsj4wMJCso3yM8wNIIvxAUIQfCIrwA0ERfiAowg8ERfiBoBjnr1NX\nV1dubcmSJU09d19fejaztWvXJuvNnPdedO18pskefxjnB5BE+IGgCD8QFOEHgiL8QFCEHwiK8ANB\nFY7zm9lUSc9L6pbkkvrdfZmZPSRpoaRD2UPvc/f8k941vsf5gfGi3nH+esI/RdIUd99hZqdI2i5p\nnqTrJX3t7n+qtynCD7ReveGfUMcTHZB0ILv9lZkNSTqjufYAVO2YjvnNbJqkiyRtyRbdZWbvmNlK\nMzs1Z50+M9tmZtua6hRAqer+br+ZnSxpo6RH3H2tmXVLOqza5wBLVDs0+E3Bc/C2H2ix0o75JcnM\nTpT0mqQ3
3P2JMerTJL3m7hcWPA/hB1qstBN7rHZp2GclDY0OfvZB4FHzJb13rE0CqE49n/bPlLRZ\n0ruSjl6D+j5JN0qaodrb/j2Sbs8+HEw9F3t+oMVKfdtfFsIPtB7n8wNIIvxAUIQfCIrwA0ERfiAo\nwg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRVeAHPkh2W9Mmo+13Zsk7Uqb11al8SvTWq\nzN566n1gW8/n/8mLm21z997KGkjo1N46tS+J3hpVVW+87QeCIvxAUFWHv7/i10/p1N46tS+J3hpV\nSW+VHvMDqE7Ve34AFakk/GZ2jZntNrOPzGxxFT3kMbM9Zvaumb1d9RRj2TRow2b23qhlp5nZW2b2\nYfZ7zGnSKurtITPbn227t83s2op6m2pm/zCzQTN738x+my2vdNsl+qpku7X9bb+ZnSDpA0lXSdon\naaukG919sK2N5DCzPZJ63b3yMWEzmyXpa0nPH50Nycz+KOmIuz+W/cN5qrv/vkN6e0jHOHNzi3rL\nm1n616pw25U543UZqtjzXyrpI3f/2N2/lfSypLkV9NHx3H2TpCM/WjxX0qrs9irV/njaLqe3juDu\nB9x9R3b7K0lHZ5audNsl+qpEFeE/Q9LeUff3qbOm/HZJb5rZdjPrq7qZMXSPmhnpM0ndVTYzhsKZ\nm9vpRzNLd8y2a2TG67Lxgd9PzXT3iyXNkXRn9va2I3ntmK2ThmtWSDpHtWncDkh6vMpmspml10ha\n5O5fjq5Vue3G6KuS7VZF+PdLmjrq/pnZso7g7vuz38OSBlQ7TOkkB49Okpr9Hq64n/9z94Pu/r27\nj0h6RhVuu2xm6TWSXnD3tdniyrfdWH1Vtd2qCP9WSeea2dlmdpKkGyStq6CPnzCzidkHMTKziZJm\nq/NmH14naUF2e4GkVyvs5Qc6ZebmvJmlVfG267gZr9297T+SrlXtE///SLq/ih5y+vqFpH9nP+9X\n3Zukl1R7G/hf1T4buVXSJEkbJH0o6e+STuug3v6i2mzO76gWtCkV9TZTtbf070h6O/u5tuptl+ir\nku3GN/yAoPjADwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUP8DUODl2qszuRAAAAAASUVORK5C\nYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x11ddd3510>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(mnist.train.images[1].reshape((28,28)), cmap=plt.cm.gray)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"$$y = xW + b$$\n",
"\n",
"$$ \\mathcal{L}(\\dot{y}, y) = -\\sum \\left[ y \\log(\\dot{y}) + (1 - y) \\log(1 - \\dot{y}) \\right] $$"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(10000, 784)"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mnist.test.images.shape"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"sess.close()\n",
"tf.reset_default_graph()"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
"Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
"Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
"Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n"
]
}
],
"source": [
"sess = tf.InteractiveSession()\n",
"\n",
"np.random.seed(52)\n",
"mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)\n",
"\n",
"x = tf.placeholder(shape=[None,784], dtype=tf.float32)\n",
"y_ = tf.placeholder(shape=[None,10], dtype=tf.float32)\n",
"\n",
"w = tf.get_variable('w', shape=[784,10], dtype=tf.float32, initializer=tf.zeros_initializer())\n",
"b = tf.get_variable('b', shape=[10], dtype=tf.float32, initializer=tf.zeros_initializer())\n",
"\n",
"y = tf.add(tf.matmul(x, w, name='wx'), b, name='y')\n",
"\n",
"loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_, name='loss'), name='avg_loss')\n",
"optimizer = tf.train.GradientDescentOptimizer(0.5)\n",
"opt_op = optimizer.minimize(loss)\n",
"\n",
"tf.global_variables_initializer().run()\n",
"\n",
"batch_size = 100\n",
"\n",
"for step in range(1000):\n",
" batch_x, batch_y = mnist.train.next_batch(100)\n",
" _ = sess.run(opt_op, feed_dict={x: batch_x, y_: batch_y})\n",
"\n",
"predictions = sess.run(y, feed_dict={x: mnist.test.images})\n",
"\n",
"actual_predictions = np.argmax(predictions, axis=1)\n",
"nsamples, ndims = mnist.test.labels.shape\n",
"acc = np.sum(np.equal(actual_predictions, np.argmax(mnist.test.labels, axis=1))) / float(nsamples)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.91990000000000005"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"acc"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@krishpop
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment