Created
March 13, 2020 17:22
-
-
Save rreece/c4cf7b1a02162a9bd6df68f1755dffa4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Main training script. | |
Call by: | |
python3 train.py --params=params.yaml | |
""" | |
import argparse | |
import numpy as np | |
import yaml | |
import tensorflow as tf | |
from tensorflow.core.protobuf import rewriter_config_pb2 | |
from data import pseudo_input_fn as input_fn | |
from model import model_fn | |
from utils import get_params | |
def parse_args(argv=None):
    """
    Parse the command-line options of this script.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as before;
            passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with attributes ``debug`` (bool), ``params`` (str)
        and ``xla`` (bool).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--debug',
        default=False,
        action='store_true',
        help='print a few input batches instead of training '
             '(takes precedence over --xla)')
    parser.add_argument(
        '--params',
        type=str,
        default='params.yaml',
        help='path to the parameters file (default: %(default)s)')
    parser.add_argument(
        '--xla',
        default=False,
        action='store_true',
        help='compile one training step with XLA and dump the graphs '
             'instead of training')
    return parser.parse_args(argv)
def train(params):
    """
    Train the model with a ``tf.estimator.Estimator``.

    Args:
        params: full parameter dict. It is handed unchanged to ``model_fn``;
            its ``'training'`` sub-dict supplies ``out_dir``, ``seed``,
            ``save_summary_steps``, ``save_checkpoints_steps``,
            ``num_ckpts`` (number of checkpoints to keep) and ``max_steps``.
    """
    train_cfg = params['training']

    # Session config with the graph rewriter's memory-optimization pass off.
    rewriter_off = rewriter_config_pb2.RewriterConfig.OFF
    session_proto = tf.ConfigProto()
    session_proto.graph_options.rewrite_options.memory_optimization = rewriter_off

    run_config = tf.estimator.RunConfig(
        model_dir=train_cfg['out_dir'],
        tf_random_seed=train_cfg['seed'],
        save_summary_steps=train_cfg['save_summary_steps'],
        save_checkpoints_steps=train_cfg['save_checkpoints_steps'],
        keep_checkpoint_max=train_cfg['num_ckpts'],
        session_config=session_proto,
    )

    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=train_cfg['out_dir'],
        params=params,
        config=run_config,
    )
    estimator.train(input_fn=input_fn, max_steps=train_cfg['max_steps'])
def print_debug(params, num_batches=3):
    """
    Print the first few batches produced by ``input_fn`` for inspection.

    Builds the input pipeline from ``params`` and runs its iterator's
    initializer and ``tf.tables_initializer()``. It then fetches
    ``num_batches`` consecutive batches, printing a ``DEBUG: batch <i>``
    header before each one.

    Args:
        params: parameter dict forwarded unchanged to ``input_fn``.
        num_batches: number of batches to fetch and print. The default of 3
            reproduces the previous hard-coded behavior. If the dataset has
            fewer batches, ``sess.run`` will raise once it is exhausted.
    """
    dataset = input_fn(params)
    iterator = dataset.make_initializable_iterator()
    next_batch = iterator.get_next()
    with tf.Session() as sess:
        sess.run(iterator.initializer)
        sess.run(tf.tables_initializer())
        for i in range(num_batches):
            print('DEBUG: batch %i' % i, flush=True)
            print_dict(sess.run(next_batch))
def print_dict(d):
    """
    Print each key/value pair of ``d``, plus its length or shape when known.

    For every item, prints ``'<key>':  <value>``. A ``list`` value also gets
    ``len(<key>) = N``, and a ``numpy.ndarray`` value also gets
    ``<key>.shape = (...)``. Every line is flushed immediately.

    Args:
        d: mapping to print (e.g. one batch fetched with ``sess.run``).
    """
    for key, val in d.items():
        # Wrap ``key`` in a 1-tuple: a bare ``% key`` would unpack a tuple key
        # as format arguments (TypeError, or silently wrong text).
        print("'%s': " % (key,), val, flush=True)
        if isinstance(val, list):
            print('len(%s) = %i' % (key, len(val)), flush=True)
        elif isinstance(val, np.ndarray):
            print('%s.shape = %s' % (key, val.shape), flush=True)
def _write_text(path, text):
    """Write ``text`` to ``path`` (overwriting it) and report it on stdout."""
    with open(path, 'w', encoding='utf-8') as f:
        f.write(text)
    print('%s written.' % path, flush=True)


def save_xla(params):
    """
    Compile one training step with XLA and dump the resulting graphs.

    The step is built by calling ``model_fn`` in TRAIN mode on placeholder
    inputs. The returned loss is made to depend on the train op, so running
    it also runs the optimizer update. The step is compiled with
    ``tensorflow.contrib.compiler.xla.compile``, and two files are written
    to the current working directory:

    * ``graph_def.pbtxt``: text form of the compiled graph's GraphDef,
      including shape annotations.
    * ``xla.pbtxt``: ``str()`` of the object returned by ``XlaExtract``.

    Args:
        params: parameter dict forwarded to ``input_fn`` and ``model_fn``.
    """
    # Imported lazily because only this code path needs tf.contrib / XLA.
    # Aliased so the builtin ``compile`` is not shadowed.
    from tensorflow.contrib.compiler.xla import compile as xla_compile
    from tensorflow.tools.xla_extract import XlaExtract

    ds = input_fn(params)
    placeholders = get_placeholders_from_dataset(ds)

    def _model_fn(*inputs):
        # Order must match get_placeholders_from_dataset: m, t, x, y.
        m, t, x, y = inputs
        features = {'m': m, 't': t, 'x': x, 'y': y}
        labels = dict()
        spec = model_fn(features, labels, tf.estimator.ModeKeys.TRAIN, params)
        # control_dependencies only affects ops *created* inside the context,
        # so a new op (tf.identity) is needed to actually tie train_op to the
        # returned loss; returning spec.loss directly would add no dependency.
        with tf.control_dependencies([spec.train_op]):
            return tf.identity(spec.loss)

    print('Compiling to XLA...', flush=True)
    (out,) = xla_compile(_model_fn, inputs=placeholders)

    print('Saving graph_def...', flush=True)
    _write_text('graph_def.pbtxt',
                str(out.graph.as_graph_def(add_shapes=True)))

    print('Doing XlaExtract...', flush=True)
    hlo_mod = XlaExtract(out)
    _write_text('xla.pbtxt', str(hlo_mod))
def get_placeholders_from_dataset(dataset_obj, batch_size=8, num_features=40):
    """
    Return fixed-shape float32 placeholders for the model inputs.

    NOTE: ``dataset_obj`` is currently ignored. The shapes are hard-coded,
    presumably so that XLA sees fully static shapes; they are not derived
    from the dataset.

    TODO(review): deriving the placeholders from the dataset's output types
    and shapes (zipping the results of ``get_output_types`` /
    ``get_output_shapes``) was attempted and did not work. One unverified
    guess is that those helpers return a nested structure (e.g. a dict keyed
    by feature name) rather than a flat list, so the zip iterates dict keys.
    An unknown (``None``) batch dimension may also be a problem for XLA.
    Confirm against ``input_fn`` before reinstating that approach.

    Args:
        dataset_obj: the input dataset (unused, see note above).
        batch_size: leading dimension of every placeholder (default 8).
        num_features: trailing dimension of the ``x`` and ``y`` placeholders
            (default 40).

    Returns:
        List of four float32 placeholders named ``m``, ``t``, ``x``, ``y``,
        in that order. ``m`` and ``t`` have shape ``(batch_size,)``; ``x``
        and ``y`` have shape ``(batch_size, num_features)``.
    """
    per_example_shape = (batch_size,)
    feature_shape = (batch_size, num_features)
    phs = [
        tf.placeholder(dtype=tf.float32, name='m', shape=per_example_shape),
        tf.placeholder(dtype=tf.float32, name='t', shape=per_example_shape),
        tf.placeholder(dtype=tf.float32, name='x', shape=feature_shape),
        tf.placeholder(dtype=tf.float32, name='y', shape=feature_shape),
    ]
    return phs
if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    args = parse_args()
    # An empty --params value falls back to calling get_params() with no
    # arguments.
    params = get_params(args.params) if args.params else get_params()
    # --debug wins over --xla; with neither flag set, run normal training.
    if args.debug:
        run = print_debug
    elif args.xla:
        run = save_xla
    else:
        run = train
    run(params)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment