# Scikit Flow - Digits example
#
# GitHub gist by @Waffleboy, forked from ilblackdragon/digits.py
# (last active June 27, 2016 17:24).
import random
from sklearn import datasets, cross_validation, metrics
import tensorflow as tf
from tensorflow.contrib import learn as skflow
# NOTE(review): this seeds only Python's `random` module, which nothing in
# this script draws from; TensorFlow/NumPy randomness is not seeded here.
random.seed(42)
# Load dataset and split it into train / test subsets.
digits = datasets.load_digits()
X, y = digits.images, digits.target
# Hold out 20% of the samples for testing; a fixed random_state makes the
# shuffle (and therefore the split) reproducible across runs.
X_train, X_test, y_train, y_test = cross_validation.train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)
# TensorFlow model using Scikit Flow ops
def conv_model(X, y):
    """Build a single-conv-layer classifier; used as the estimator's model_fn.

    Args:
        X: Batch of input images. Assumed to be a 3-D tensor with no channel
            axis (batch, height, width) -- TODO confirm; the digits images
            loaded above are what gets passed in.
        y: Target labels, forwarded unchanged to the logistic-regression head.

    Returns:
        Whatever ``skflow.models.logistic_regression`` returns for the pooled
        features (presumably predictions and loss -- confirm against the
        installed skflow version).
    """
    # Single source of truth for the filter count: conv2d's output channels
    # and the reshape below must agree.
    n_filters = 12
    # Append a trailing channel axis (assuming X is 3-D) so conv2d gets a
    # 4-D input.
    X = tf.expand_dims(X, 3)
    # 3x3 convolution, then max over axes 1 and 2 (presumably the spatial
    # axes) -> one pooled value per filter.
    features = tf.reduce_max(skflow.ops.conv2d(X, n_filters, [3, 3]), [1, 2])
    # Pin the shape to (batch, n_filters) for the dense head.
    features = tf.reshape(features, [-1, n_filters])
    return skflow.models.logistic_regression(features, y)
# Create a classifier, train and predict.
classifier = skflow.TensorFlowEstimator(
    model_fn=conv_model,
    n_classes=10,
    steps=500,
    learning_rate=0.05,
    batch_size=128,
)
classifier.fit(X_train, y_train)
# Score predictions on the held-out subset.
y_pred = classifier.predict(X_test)
score = metrics.accuracy_score(y_test, y_pred)
print('Accuracy: %f' % score)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.