Created July 16, 2016 17:14
Sample code for skflow (tf.contrib.learn): train a DNN classifier on the Iris dataset and print its test accuracy.
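from sklearn import metrics
from sklearn.model_selection import train_test_split  # scikit-learn >= 0.18
from tensorflow.contrib import learn  # TensorFlow 1.x only; tf.contrib was removed in 2.0

iris = learn.datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42)
# Infer real-valued feature columns (the 4 Iris measurements) from the training data.
feature_columns = learn.infer_real_valued_columns_from_input(X_train)
# Three hidden layers of 10, 20 and 10 units; 3 output classes (Iris species).
classifier = learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3,
                                 feature_columns=feature_columns)
classifier.fit(x=X_train, y=y_train, steps=200)
# predict() may return a generator in some contrib.learn releases; materialize it.
predictions = list(classifier.predict(X_test))
print(metrics.accuracy_score(y_test, predictions))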
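The sample leans on two scikit-learn utilities: a shuffled 80/20 hold-out split and an accuracy score. The stdlib-only sketch below shows roughly what they compute. The names `holdout_split` and `accuracy` are hypothetical helpers, not part of scikit-learn or TensorFlow. The split rounds the test count up like `train_test_split` does for a float `test_size` (150 Iris rows give 120 train / 30 test), but it uses Python's own shuffle, so the rows it picks will not match scikit-learn's for the same seed.

```python
import math
import random


def holdout_split(xs, ys, test_size=0.2, seed=42):
    """Hypothetical helper: shuffle row indices and hold out a test fraction.

    Returns (x_train, x_test, y_train, y_test). The test count is rounded up,
    but the shuffle order differs from scikit-learn's train_test_split.
    """
    if len(xs) != len(ys):
        raise ValueError("xs and ys must have the same length")
    n_test = math.ceil(test_size * len(xs))
    idx = list(range(len(xs)))
    random.Random(seed).shuffle(idx)  # seeded, so the split is reproducible
    test_idx, train_idx = idx[:n_test], idx[n_test:]
    return ([xs[i] for i in train_idx], [xs[i] for i in test_idx],
            [ys[i] for i in train_idx], [ys[i] for i in test_idx])


def accuracy(y_true, y_pred):
    """Hypothetical helper: fraction of positions where prediction == label."""
    if len(y_true) != len(y_pred):
        raise ValueError("length mismatch")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


if __name__ == "__main__":
    # Stand-in data shaped like Iris: 150 rows, 3 classes of 50.
    X = list(range(150))
    y = [i // 50 for i in range(150)]
    X_tr, X_te, y_tr, y_te = holdout_split(X, y)
    print(len(X_tr), len(X_te))  # 120 30
    print(accuracy([0, 1, 2, 2], [0, 1, 1, 2]))  # 0.75
```

Accuracy is symmetric in its two arguments, so swapping them does not change the number, but scikit-learn's `accuracy_score` documents the order as `(y_true, y_pred)`.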