|
import tensorflow as tf |
|
|
|
from tensorflow.contrib.learn.python.learn import learn_runner |
|
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib |
|
|
|
import os |
|
import json |
|
import numpy as np |
|
import functools |
|
|
|
# Model directory handed to learn_runner.run and used as the estimator's model_dir.
output_dir = "/tmp/iris_model"

# Data sets: local cache locations of the Iris CSV files ...
IRIS_TRAINING = "/tmp/irisdata/iris_training.csv"
IRIS_TEST = "/tmp/irisdata/iris_test.csv"

# ... and the URLs each file is downloaded from when it is missing locally.
IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"
IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"
|
|
|
def get_data():
    """Download the Iris CSV files if needed and load them as datasets.

    Each of ``IRIS_TRAINING`` / ``IRIS_TEST`` is fetched from its matching
    URL only when the file does not already exist on disk.

    Returns:
        A ``(training_set, test_set)`` tuple, each produced by
        ``tf.contrib.learn.datasets.base.load_csv_with_header`` with float32
        features and integer targets.
    """
    # urlopen lives in urllib.request on Python 3 and in urllib on Python 2.
    try:
        from urllib.request import urlopen
    except ImportError:  # Python 2
        from urllib import urlopen
    from contextlib import closing

    for path, url in ((IRIS_TRAINING, IRIS_TRAINING_URL),
                      (IRIS_TEST, IRIS_TEST_URL)):
        if os.path.exists(path):
            continue
        # The cache directory is not guaranteed to exist on a fresh machine.
        data_dir = os.path.dirname(path)
        if data_dir and not os.path.isdir(data_dir):
            os.makedirs(data_dir)
        # closing() releases the connection even if read() fails.
        with closing(urlopen(url)) as response:
            raw = response.read()
        # The response body is bytes on Python 3, so write in binary mode.
        with open(path, "wb") as f:
            f.write(raw)

    # Load datasets. ``np.int`` was merely an alias of the builtin ``int``
    # and no longer exists in recent NumPy releases.
    training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
        filename=IRIS_TRAINING,
        target_dtype=int,
        features_dtype=np.float32)
    test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
        filename=IRIS_TEST,
        target_dtype=int,
        features_dtype=np.float32)
    return training_set, test_set
|
|
|
|
|
def get_input_fn(dataset):
    """Return constant feature and label tensors built from *dataset*.

    In this script it is bound with ``functools.partial(..., dataset=...)``
    and passed to ``tf.contrib.learn.Experiment`` as an input function.

    Args:
        dataset: object exposing ``data`` (features) and ``target`` (labels),
            e.g. a result of ``load_csv_with_header``.

    Returns:
        A ``(features, labels)`` tuple of ``tf.constant`` tensors.
    """
    features, labels = dataset.data, dataset.target
    return tf.constant(features), tf.constant(labels)
|
|
|
def get_estimator(output_dir, config):
    """Build the DNN classifier trained by the Iris experiment.

    Args:
        output_dir: directory passed to the estimator as ``model_dir``.
        config: run configuration forwarded unchanged to the estimator.

    Returns:
        A ``tf.contrib.learn.DNNClassifier`` with hidden layers of 10, 20 and
        10 units, 3 output classes, and one 4-dimensional real-valued input
        column.
    """
    # A single real-valued column, unnamed (""), carrying all 4 features.
    columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
    layer_sizes = [10, 20, 10]
    return tf.contrib.learn.DNNClassifier(
        feature_columns=columns,
        hidden_units=layer_sizes,
        n_classes=3,
        model_dir=output_dir,
        config=config,
    )
|
|
|
def create_experiment_fn(config):
    """Return an ``experiment_fn`` closure suitable for ``learn_runner.run``.

    Args:
        config: run configuration captured by the closure and passed to
            ``get_estimator``.

    Returns:
        ``experiment_fn(output_dir)``, which loads the Iris data and returns a
        ``tf.contrib.learn.Experiment`` (1000 train steps, 100 eval steps,
        continuous_eval_throttle_secs=15, eval_delay_secs=10).
    """
    def experiment_fn(output_dir):
        train_data, eval_data = get_data()
        # Bind the datasets so the input functions take no arguments.
        train_fn = functools.partial(get_input_fn, dataset=train_data)
        eval_fn = functools.partial(get_input_fn, dataset=eval_data)
        estimator = get_estimator(output_dir, config)
        return tf.contrib.learn.Experiment(
            estimator=estimator,
            train_input_fn=train_fn,
            eval_input_fn=eval_fn,
            train_steps=1000,
            eval_steps=100,
            continuous_eval_throttle_secs=15,
            eval_delay_secs=10,
        )

    return experiment_fn
|
|
|
def main(args=None):
    """Configure this process as one task of a local PS/worker cluster and run.

    The cluster has one parameter server at 127.0.0.1:9000 and one worker at
    127.0.0.1:9001. ``args.type == "worker"`` makes this process the worker;
    any other value makes it the parameter server.

    Args:
        args: parsed command-line namespace with a ``type`` attribute. When
            omitted, the command line is parsed here, so both ``main()`` and
            ``main(args)`` work.
    """
    if args is None:
        import argparse
        parser = argparse.ArgumentParser()
        parser.add_argument("type", type=str, choices=("worker", "ps"),
                            help="task type this process runs as")
        args = parser.parse_args()

    task_type = "worker" if args.type == "worker" else "ps"
    tf_config = {
        "cluster": {
            "ps": ["127.0.0.1:9000"],
            "worker": ["127.0.0.1:9001"],
        },
        "task": {"type": task_type, "index": 0},
    }

    # Export TF_CONFIG before constructing RunConfig so the new config
    # sees this process's cluster/task layout.
    os.environ["TF_CONFIG"] = json.dumps(tf_config)
    config = run_config_lib.RunConfig()

    learn_runner.run(
        experiment_fn=create_experiment_fn(config),
        output_dir=output_dir)
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    # Only "worker" and "ps" are meaningful; main() would otherwise treat any
    # other value (including typos) as the parameter-server role.
    parser.add_argument("type", type=str, choices=("worker", "ps"),
                        help="task type this process runs as")

    args = parser.parse_args()
    main(args)