Created
December 9, 2017 03:33
-
-
Save anonymous/0ba3e190a729b7a059c8013b69a1e395 to your computer and use it in GitHub Desktop.
Simple 2-layer MLP to clarify a TensorFlow API question. Uses the Statoil/C-CORE dataset from Kaggle.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
import tensorflow as tf | |
def normalize(arr, axis):
    """Standardize ``arr`` in place to zero mean and unit deviation along ``axis``.

    Args:
        arr: NumPy array to standardize. It is modified in place, so it should
            have a floating-point dtype able to hold the fractional results.
        axis: Axis or tuple of axes over which the mean and standard deviation
            are computed (forwarded to ``arr.mean`` and ``arr.std``).

    Returns:
        None. The standardized values are written back into ``arr``.
    """
    means = arr.mean(axis=axis, keepdims=True)
    devs = arr.std(axis=axis, keepdims=True)
    # A constant slice has zero deviation; dividing by it would fill the slice
    # with NaN (0 / 0). Dividing by 1 instead keeps the centered zeros.
    devs[devs == 0] = 1
    arr -= means
    arr /= devs
# Load the training table; the columns used below are 'band_1', 'band_2'
# and 'is_iceberg'.
df = pd.read_json('data/processed/train.json')

# Stack the two bands on a leading axis, move that axis to the end, and
# reshape so that every sample becomes a 75x75 image with 2 channels.
images = np.array([df['band_1'], df['band_2']], dtype=np.float32)
images = np.moveaxis(images, 0, -1).reshape(-1, 75, 75, 2)

# `.values` replaces `Series.as_matrix()`, which was deprecated in pandas 0.23
# and removed in pandas 1.0; both return the underlying NumPy array.
labels = df['is_iceberg'].values

# Standardize each of the two channels over all samples and pixels (in place).
normalize(images, axis=(0, 1, 2))
seed = 42
# Reset the graph BEFORE seeding: tf.set_random_seed() stores the seed on the
# current default graph, so resetting afterwards would replace that graph and
# silently drop the seed, leaving weight initialization unseeded.
tf.reset_default_graph()
tf.set_random_seed(seed)
np.random.seed(seed)

# Inputs: a batch of 75x75 two-channel images and one integer label per image.
X = tf.placeholder(tf.float32, shape=(None, 75, 75, 2))
# NOTE(review): `(None)` is plain `None`, not a 1-tuple, so this placeholder
# has an unspecified shape. Kept as-is because the commented-out Setup 2 may
# rely on it -- confirm before tightening to `(None,)`.
y = tf.placeholder(tf.int32, shape=(None))

# Flatten each image and pass it through two fully connected ReLU layers.
fc0 = tf.layers.flatten(X)
fc1 = tf.layers.dense(fc0, 512, activation=tf.nn.relu)
fc2 = tf.layers.dense(fc1, 256, activation=tf.nn.relu)

# Setup 1: two logits per sample, sparse softmax cross-entropy against the
# integer labels; a sample counts as correct when its label is the top-1 logit.
out = tf.layers.dense(fc2, 2)
loss = tf.losses.sparse_softmax_cross_entropy(labels=y, logits=out)
correct_flags = tf.nn.in_top_k(out, y, 1)

# Setup 2 (alternative under discussion, disabled): a single sigmoid output
# trained with log loss; a sample counts as correct when the rounded output
# equals the label.
# out = tf.layers.dense(fc2, 1, activation=tf.sigmoid)
# loss = tf.losses.log_loss(labels=y, predictions=out)
# correct_flags = tf.equal(y, tf.cast(tf.round(out), tf.int32))

# Adam with learning rate 0.001; accuracy is the mean of the per-sample
# correctness flags.
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
accuracy = tf.reduce_mean(tf.cast(correct_flags, tf.float32))
# Full-batch training: every step feeds the complete training set, so one
# optimizer step corresponds to one epoch. The printed loss and accuracy are
# measured on that same training data.
with tf.Session() as session:
    session.run(tf.global_variables_initializer())

    fetches = [train_step, loss, accuracy]
    full_batch = {X: images, y: labels}

    for epoch in range(1, 101):
        _, loss_val, acc_val = session.run(fetches, feed_dict=full_batch)
        print('Epoch: {} Loss: {} Accuracy: {}'.format(epoch, loss_val, acc_val))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment