$ python script.py ps
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE3 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job ps -> {0 -> localhost:9000}
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job worker -> {0 -> 127.0.0.1:9001}
I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:221] Started server with target: grpc://localhost:9000
$ python experiment.py worker
[exits silently, no output]
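
To narrow down the silent exit, a minimal diagnostic sketch (my own addition, not part of the original script): the TF 1.x tf.contrib.learn RunConfig exposes task_type, task_id, master, and cluster_spec, so printing them right before the learn_runner.run() call in the script below should show whether TF_CONFIG was parsed and which role the process thinks it has.

# Diagnostic only (hypothetical addition): drop these lines into main()
# just before learn_runner.run(); `config` is the RunConfig built there.
print("task_type:", config.task_type)
print("task_id:", config.task_id)
print("master:", config.master)
print("cluster:", config.cluster_spec.as_dict() if config.cluster_spec else None)

The script itself:
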
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
import os
import json
import urllib
import numpy as np
import functools

output_dir = "/tmp/iris_model"

# Data sets
IRIS_TRAINING = "/tmp/irisdata/iris_training.csv"
IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"
IRIS_TEST = "/tmp/irisdata/iris_test.csv"
IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

def get_data():
    if not os.path.exists(IRIS_TRAINING):
        raw = urllib.urlopen(IRIS_TRAINING_URL).read()
        with open(IRIS_TRAINING, "w") as f:
            f.write(raw)
    if not os.path.exists(IRIS_TEST):
        raw = urllib.urlopen(IRIS_TEST_URL).read()
        with open(IRIS_TEST, "w") as f:
            f.write(raw)
    # Load datasets.
    training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
        filename=IRIS_TRAINING,
        target_dtype=np.int,
        features_dtype=np.float32)
    test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
        filename=IRIS_TEST,
        target_dtype=np.int,
        features_dtype=np.float32)
    return training_set, test_set

def get_input_fn(dataset):
    # Input fn for the Experiment: returns the whole dataset as constant tensors.
    x = tf.constant(dataset.data)
    y = tf.constant(dataset.target)
    return x, y

def get_estimator(output_dir, config):
    # Specify that all features have real-value data
    feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
    # Build 3 layer DNN with 10, 20, 10 units respectively.
    classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                                hidden_units=[10, 20, 10],
                                                n_classes=3,
                                                model_dir=output_dir,
                                                config=config)
    return classifier

def create_experiment_fn(config):
    def experiment_fn(output_dir):
        training_set, test_set = get_data()
        train_input_fn = functools.partial(
            get_input_fn, dataset=training_set)
        eval_input_fn = functools.partial(
            get_input_fn, dataset=test_set)
        return tf.contrib.learn.Experiment(
            estimator=get_estimator(output_dir, config),
            train_input_fn=train_input_fn,
            eval_input_fn=eval_input_fn,
            train_steps=1000,
            eval_steps=100,
            continuous_eval_throttle_secs=15,
            eval_delay_secs=10)
    return experiment_fn

def main(args):
    tf_config = {
        "cluster": {
            'ps': ['127.0.0.1:9000'],
            'worker': ['127.0.0.1:9001']
        }
    }
    if args.type == "worker":
        tf_config["task"] = {'type': 'worker', 'index': 0}
    else:
        tf_config["task"] = {'type': 'ps', 'index': 0}
    # TF_CONFIG must be in the environment before RunConfig is constructed,
    # so that the cluster spec and task are picked up from it.
    os.environ['TF_CONFIG'] = json.dumps(tf_config)
    config = run_config_lib.RunConfig()
    learn_runner.run(
        experiment_fn=create_experiment_fn(config),
        output_dir=output_dir)

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("type", type=str)
    args = parser.parse_args()
    main(args)
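
For reference, both roles are launched from the same script with the role as a positional argument (the logs above show one invocation as script.py and one as experiment.py, presumably the same file); assuming it is saved as script.py, the run looks like:

$ python script.py ps       # terminal 1: parameter server, keeps serving on localhost:9000 (see log above)
$ python script.py worker   # terminal 2: worker, expected to run the Experiment but exits silently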