Skip to content

Instantly share code, notes, and snippets.

@mwilliammyers
Created July 3, 2019 05:20
Show Gist options
  • Save mwilliammyers/ce9adca6b0f7d30f31f5bce59bcbfe44 to your computer and use it in GitHub Desktop.
Save mwilliammyers/ce9adca6b0f7d30f31f5bce59bcbfe44 to your computer and use it in GitHub Desktop.
TensorFlow Estimator demo of a simple linear regression model
import logging
import numpy as np
import tensorflow as tf
logging.getLogger().setLevel(logging.INFO)

DATA_DIR = "/mnt/pccfs/not_backed_up/data/eve-embeddings-prod"

# Each .npy file stores a single Python object (used below as a mapping with
# "features" and "labels" keys), so loading it requires unpickling. NumPy
# >= 1.16.3 refuses to unpickle by default and raises ValueError, hence
# allow_pickle=True. Unpickling can execute arbitrary code: only point
# DATA_DIR at files you trust.
train_data = np.load(f"{DATA_DIR}/ccc_train.npy", allow_pickle=True)
train_split = train_data.item()  # unwrap the stored object once, not per key
train_features = train_split["features"]
train_labels = train_split["labels"]

test_data = np.load(f"{DATA_DIR}/ccc_evaluation.npy", allow_pickle=True)
test_split = test_data.item()
test_features = test_split["features"]
test_labels = test_split["labels"]
# Width of each of the two embeddings concatenated in every feature row: the
# first DIM columns feed "previous_response", the rest feed "text".
DIM = 512


def _split_features(features):
    """Split a feature matrix into the named inputs the model expects.

    NOTE(review): assumes ``features`` is 2-D with ``2 * DIM`` columns, to
    match the two ``shape=DIM`` numeric columns below — confirm against the
    data files.
    """
    return {
        "previous_response": features[:, :DIM],
        "text": features[:, DIM:],
    }


train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x=_split_features(train_features),
    y=train_labels,
    batch_size=2,
    num_epochs=None,  # cycle indefinitely; training length is set by `steps`
    shuffle=True)

# Evaluation visits every test example exactly once, in a fixed order, so the
# reported loss covers the whole test set rather than a shuffled, repeating
# sample of it cut off after an arbitrary number of steps.
test_input_fn = tf.estimator.inputs.numpy_input_fn(
    x=_split_features(test_features),
    y=test_labels,
    num_epochs=1,
    shuffle=False)

model = tf.estimator.LinearRegressor(
    feature_columns=[
        tf.feature_column.numeric_column("previous_response", shape=DIM),
        tf.feature_column.numeric_column("text", shape=DIM),
    ],
    label_dimension=DIM,
    # optimizer=tf.train.AdamOptimizer(),
    model_dir='./output')

# NOTE(review): training is disabled, so evaluate() depends on a checkpoint
# already present in model_dir from an earlier run — confirm one exists.
# model.train(input_fn=train_input_fn, steps=50000)

# steps=None: evaluate until test_input_fn is exhausted (one full pass).
eval_results = model.evaluate(input_fn=test_input_fn, steps=None)
average_loss = eval_results["average_loss"]
print(f"Average loss in testing: {average_loss:.4f}")

# Prediction example; inputs must use the same feature keys as training.
# predict_input_fn = tf.estimator.inputs.numpy_input_fn(
#     x=_split_features(test_features[:2]), num_epochs=1, shuffle=False)
#
# for p in model.predict(input_fn=predict_input_fn):
#     v = p["predictions"][0]
#     print(f"-> {v:.4f}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment