@ilblackdragon
Last active January 26, 2017 18:22
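import random
import tensorflow as tf
from sklearn.metrics import accuracy_score
from tensorflow.contrib import layers, learn

def dnn_tanh(features, target):
    # One-hot encode the binary target (on value 1.0, off value 0.0).
    target = tf.one_hot(target, 2, 1.0, 0.0)
    # Three fully connected tanh layers with 10, 20 and 10 units.
    logits = layers.stack(features, layers.fully_connected, [10, 20, 10],
                          activation_fn=tf.tanh)
    prediction, loss = learn.models.logistic_regression(logits, target)
    train_op = layers.optimize_loss(loss,
        tf.contrib.framework.get_global_step(), optimizer='SGD', learning_rate=0.05)
    return tf.argmax(prediction, dimension=1), loss, train_op

random.seed(42)
classifier = learn.Estimator(model_fn=dnn_tanh)
classifier.fit(X_train, y_train, batch_size=128, steps=100)
# Assumption: X_test and y_test are the held-out part of the same split
# that produced X_train and y_train; the snippet never defines `score`.
score = accuracy_score(y_test, classifier.predict(X_test))
print("Accuracy: %f" % score)
# Outputs: Accuracy: 0.692737430168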
@maininformer

Hello, I was reading and following your post on Medium. Thank you for that. It seems, though, that skflow.models has no method called logistic_classifier, or at least not anymore. Maybe logistic_regression instead?

@vinayprabhu

Just a note: if you try to get the cross-validation score with

    from sklearn.cross_validation import cross_val_score
    scores = cross_val_score(classifier, X, y, cv=4)

you will get the following error:

TypeError: If no scoring is specified, the estimator passed should have a 'score' method. The estimator TensorFlowEstimator(batch_size=128, class_weight=None,
          continue_training=False, early_stopping_rounds=None,
          keep_checkpoint_every_n_hours=10000, learning_rate=0.05,
          max_to_keep=5, model_fn=<function dnn_tanh at 0x7f41114e4c80>,
          n_classes=2, num_cores=4, optimizer='SGD', steps=500,
          tf_master='', tf_random_seed=42, verbose=1) does not.

Solution? Add the scoring method as:
    scores = cross_val_score(classifier, X, y, cv=4, scoring="accuracy")

I mention this here because you will not encounter this situation with skflow.TensorFlowDNNClassifier.
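
To make the reason for that error concrete, here is a rough, standard-library-only sketch of the check it describes. It is not scikit-learn's actual code: `NoScoreEstimator`, `accuracy`, and `pick_scorer` are hypothetical names made up for illustration. The point is that an estimator with only `fit` and `predict` can still be scored once a metric is named explicitly, because accuracy needs nothing beyond `predict`.

```python
class NoScoreEstimator(object):
    """Hypothetical estimator: has fit and predict, but no score method."""

    def fit(self, X, y):
        # Remember the most frequent training label.
        labels = list(y)
        self.majority_ = max(set(labels), key=labels.count)
        return self

    def predict(self, X):
        # Always predict the majority label seen during fit.
        return [self.majority_ for _ in X]


def accuracy(y_true, y_pred):
    # Fraction of predictions that match the true labels.
    matches = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return matches / float(len(y_true))


def pick_scorer(estimator, scoring=None):
    """Hypothetical helper mirroring the behaviour the error message describes."""
    if scoring == "accuracy":
        # An explicit metric only needs predict().
        return lambda est, X, y: accuracy(y, est.predict(X))
    if scoring is None:
        # Without a metric, fall back to the estimator's own score().
        if hasattr(estimator, "score"):
            return lambda est, X, y: est.score(X, y)
        raise TypeError("If no scoring is specified, the estimator passed "
                        "should have a 'score' method.")
    raise ValueError("unsupported scoring: %r" % (scoring,))


# Toy data, invented for this sketch.
X = [[0], [1], [2], [3], [4], [5]]
y = [1, 1, 1, 0, 1, 0]

est = NoScoreEstimator().fit(X[:4], y[:4])
scorer = pick_scorer(est, scoring="accuracy")
held_out_accuracy = scorer(est, X[4:], y[4:])
print("Held-out accuracy: %f" % held_out_accuracy)
```

Calling `pick_scorer(est)` with no `scoring` argument raises the `TypeError`, which is the same situation as the traceback above.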
