# Source: GitHub gist rreece/c4cf7b1a02162a9bd6df68f1755dffa4
# Author: @rreece
# Created: March 13, 2020 17:22
"""
Main training script.
Call by:
python3 train.py --params=params.yaml
"""
import argparse
import numpy as np
import yaml
import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2
from data import pseudo_input_fn as input_fn
from model import model_fn
from utils import get_params
def parse_args():
    """
    Parse the command-line options of the training script.

    Recognized options:
      --debug   boolean flag, off unless given
      --params  path of the parameters file (default 'params.yaml')
      --xla     boolean flag, off unless given

    Returns the populated argparse namespace.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--debug', action='store_true', default=False)
    cli.add_argument('--params', default='params.yaml', type=str)
    cli.add_argument('--xla', action='store_true', default=False)
    return cli.parse_args()
def train(params):
    """
    Build a tf.estimator.Estimator from ``params`` and run training.

    Reads the following keys from ``params['training']``: 'out_dir',
    'seed', 'save_summary_steps', 'save_checkpoints_steps', 'num_ckpts'
    and 'max_steps'. The full ``params`` dict is forwarded to the
    Estimator, which makes it available to ``model_fn``.
    """
    train_cfg = params['training']

    # Session config with the graph rewriter's memory optimization off.
    session_config = tf.ConfigProto()
    rewrite_options = session_config.graph_options.rewrite_options
    rewrite_options.memory_optimization = rewriter_config_pb2.RewriterConfig.OFF

    run_config = tf.estimator.RunConfig(
        model_dir=train_cfg['out_dir'],
        tf_random_seed=train_cfg['seed'],
        save_summary_steps=train_cfg['save_summary_steps'],
        save_checkpoints_steps=train_cfg['save_checkpoints_steps'],
        keep_checkpoint_max=train_cfg['num_ckpts'],
        session_config=session_config,
    )

    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=train_cfg['out_dir'],
        params=params,
        config=run_config,
    )
    estimator.train(input_fn=input_fn, max_steps=train_cfg['max_steps'])
def print_debug(params, num_batches=3):
    """
    Print the first few batches produced by ``input_fn`` for inspection.

    Args:
        params: parameters object, passed unchanged to ``input_fn``.
        num_batches: number of batches to fetch and print. Defaults to 3,
            which matches the previous hard-coded behavior.

    Each batch is printed with ``print_dict`` after a 'DEBUG: batch <i>'
    header. If the dataset runs out before ``num_batches`` batches were
    fetched, a note is printed and the function returns instead of
    letting tf.errors.OutOfRangeError propagate.
    """
    dataset = input_fn(params)
    iterator = dataset.make_initializable_iterator()
    next_batch = iterator.get_next()
    with tf.Session() as sess:
        sess.run(iterator.initializer)
        sess.run(tf.tables_initializer())
        for idx in range(num_batches):
            print('DEBUG: batch %i' % idx, flush=True)
            try:
                batch = sess.run(next_batch)
            except tf.errors.OutOfRangeError:
                # get_next() signals the end of the dataset this way.
                print('DEBUG: dataset exhausted', flush=True)
                break
            print_dict(batch)
def print_dict(d):
    """
    Print every key/value pair of ``d``, plus size info for containers.

    For each entry the key (in single quotes) and the value are printed.
    Values that are lists additionally get a ``len(<key>) = <n>`` line,
    and numpy arrays a ``<key>.shape = <shape>`` line. Every print is
    flushed immediately.
    """
    for key, val in d.items():
        # Wrap the key in a 1-tuple: with a bare ``% key`` a tuple-valued
        # key would be unpacked as format arguments and raise TypeError.
        print("'%s': " % (key,), val, flush=True)
        if isinstance(val, list):
            print('len(%s) = %i' % (key, len(val)), flush=True)
        if isinstance(val, np.ndarray):
            print('%s.shape = %s' % (key, val.shape), flush=True)
def save_xla(params):
    """
    Compile one training step with XLA and dump the results as text.

    Two files are written to the current working directory:
      graph_def.pbtxt  str() of the compiled output's GraphDef, with
                       shapes added
      xla.pbtxt        str() of whatever XlaExtract returns for the
                       compiled output

    Args:
        params: parameters object, passed to ``input_fn`` and ``model_fn``.
    """
    # Kept function-local so the module imports without these packages
    # (presumably xla_extract is not part of every TensorFlow build).
    # Aliased so the builtin ``compile`` is not shadowed.
    from tensorflow.contrib.compiler.xla import compile as xla_compile
    from tensorflow.tools.xla_extract import XlaExtract

    ds = input_fn(params)
    phs = get_placeholders_from_dataset(ds)

    def _model_fn(*inputs):
        """Build the model in train mode; return the loss, gated on train_op."""
        m, t, x, y = inputs
        features = {
            'm': m,
            't': t,
            'x': x,
            'y': y,
        }
        labels = dict()
        mode = 'train'
        spec = model_fn(features, labels, mode, params)
        with tf.control_dependencies([spec.train_op]):
            # control_dependencies only affects ops created inside the
            # context. Returning the existing loss tensor directly would
            # not depend on train_op, so wrap it in a new identity op.
            return tf.identity(spec.loss)

    print('Compiling to XLA...', flush=True)
    (out,) = xla_compile(_model_fn, inputs=phs)

    print('Saving graph_def...', flush=True)
    graph_def_string = str(out.graph.as_graph_def(add_shapes=True))
    graph_def_path = 'graph_def.pbtxt'
    with open(graph_def_path, 'w') as f:
        f.write(graph_def_string)
    print('%s written.' % graph_def_path, flush=True)

    print('Doing XlaExtract...', flush=True)
    hlo_mod = XlaExtract(out)
    hlo_path = 'xla.pbtxt'
    with open(hlo_path, 'w') as f:
        f.write(str(hlo_mod))
    print('%s written.' % hlo_path, flush=True)
def get_placeholders_from_dataset(dataset_obj):
    """
    Return the float32 placeholders 'm', 't', 'x' and 'y', in that order.

    ``dataset_obj`` is currently ignored. The shapes are hard-coded:
    'm' and 't' are (8,), 'x' and 'y' are (8, 40).

    NOTE: an earlier attempt derived dtype and shape of each placeholder
    from the dataset via tf.data.get_output_types /
    tf.data.get_output_shapes. It did not work (the reason was not
    recorded), hence the fixed shapes here.
    """
    leading_dim = 8  # presumably the batch size; confirm against input_fn
    specs = (
        ('m', (leading_dim,)),
        ('t', (leading_dim,)),
        ('x', (leading_dim, 40)),
        ('y', (leading_dim, 40)),
    )
    return [
        tf.placeholder(dtype=tf.float32, name=name, shape=shape)
        for name, shape in specs
    ]
if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    args = parse_args()

    # An empty --params value falls back to get_params() without a path.
    params_file = args.params
    params = get_params(params_file) if params_file else get_params()

    # --debug takes precedence over --xla; plain training is the default.
    if args.debug:
        action = print_debug
    elif args.xla:
        action = save_xla
    else:
        action = train
    action(params)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment