Created
May 1, 2020 18:43
-
-
Save rreece/0a6e2c418ea81263b09a3071eb2b1c63 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import numpy as np | |
import tensorflow as tf | |
from tensorflow.core.protobuf import rewriter_config_pb2 | |
from model import model_fn | |
from data import input_fn | |
from utils import ( | |
DEFAULT_PARAMS_FILE, | |
get_params, | |
) | |
def parse_args():
    """Parse the command-line arguments of this script.

    Returns:
        argparse.Namespace with attributes ``params`` (path to the params
        .yaml file), ``model_dir``, ``device`` and ``debug`` (bool flag).
    """
    parser = argparse.ArgumentParser()
    # (flag, keyword arguments for add_argument), registered in this order.
    option_specs = (
        ('--params', dict(
            default=DEFAULT_PARAMS_FILE,
            help='Path to .yaml file with model parameters',
        )),
        ('--model_dir', dict(
            default=None,
            help='Model directory',
        )),
        ('--device', dict(
            default=None,
            help='Force model to run on a specific device (e.g., --device /gpu:0)',
        )),
        ('--debug', dict(
            default=False,
            action='store_true',
        )),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def train(params):
    """Build a ``tf.estimator.Estimator`` from ``params`` and train it.

    Reads from ``params['training']``:
      * ``runconfig_params`` -- keyword arguments forwarded to
        ``tf.estimator.RunConfig``;
      * ``train_params`` -- keyword arguments forwarded to
        ``Estimator.train``.
    The whole ``params`` dict is handed to the Estimator as its ``params``.
    """
    training_cfg = params['training']

    # Session config with the graph rewriter's memory_optimization option
    # switched OFF.
    session_config = tf.compat.v1.ConfigProto()
    rewrite_opts = session_config.graph_options.rewrite_options
    rewrite_opts.memory_optimization = rewriter_config_pb2.RewriterConfig.OFF

    run_config = tf.estimator.RunConfig(
        session_config=session_config,
        **training_cfg['runconfig_params'],
    )
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        params=params,
        config=run_config,
    )
    estimator.train(input_fn, **training_cfg['train_params'])
def run_debug(params, num_batches=3):
    """Run debug printing of the first few batches produced by ``input_fn``.

    Builds the input pipeline in a fresh graph, pulls batches through a
    TF1-style session and prints each one with ``print_obj``.

    Args:
        params: parameter dict passed unchanged to ``input_fn``.
        num_batches: number of batches to print (default 3, the count the
            original hard-coded).
    """
    # Use the tf.compat.v1 APIs (as the rest of this file does) inside an
    # explicit graph: under TF2, eager execution is on by default, and
    # sessions / initializable iterators only work in graph mode.
    with tf.Graph().as_default():
        ds = input_fn(params)
        ds_iter = tf.compat.v1.data.make_initializable_iterator(ds)
        next_el = ds_iter.get_next()
        with tf.compat.v1.Session() as sess:
            sess.run(ds_iter.initializer)
            sess.run(tf.compat.v1.tables_initializer())
            for i_batch in range(1, num_batches + 1):
                print('DEBUG: batch %i' % i_batch, flush=True)
                try:
                    print_obj(sess.run(next_el))
                except tf.errors.OutOfRangeError:
                    # The dataset ran out before num_batches; report instead
                    # of crashing the debug run.
                    print('DEBUG: dataset exhausted after %i batch(es)'
                          % (i_batch - 1), flush=True)
                    break
def print_obj(obj, name=None):
    """Print a debug summary of ``obj``, recursing into containers.

    Lists, tuples, dicts and numpy arrays are dispatched to their dedicated
    printers; any other value has its type and then its value printed.

    Args:
        obj: value to describe (e.g. the result of ``sess.run``).
        name: optional label passed to the container printers.
    """
    # A single if/elif chain, so each object is handled by exactly one
    # branch. Previously the list check was a separate ``if``, so lists also
    # fell through to ``else`` and were dumped in full.
    if isinstance(obj, list):
        print_list(obj, name)
    elif isinstance(obj, tuple):
        print_tuple(obj, name)
    elif isinstance(obj, dict):
        print_dict(obj, name)
    elif isinstance(obj, np.ndarray):
        print_np(obj, name)
    else:
        print(type(obj), flush=True)
        print(obj, flush=True)
def print_list(obj, name=None):
    """Print a sequence's length and, if it is short, each of its elements.

    Elements are printed via ``print_obj`` with labels ``name[i]``, but only
    when the sequence has fewer than 10 items.
    """
    label = name or 'list'
    n_items = len(obj)
    print('List: len(%s) = %i' % (label, n_items), flush=True)
    if n_items >= 10:
        # Long sequences: summary line only.
        return
    for idx, item in enumerate(obj):
        print_obj(item, name='%s[%i]' % (label, idx))
def print_tuple(obj, name=None):
    """Print a debug summary of a tuple by delegating to ``print_list``.

    Args:
        obj: the tuple to describe.
        name: label for the tuple; ``'tuple'`` is used when it is empty/None.
    """
    # Honour the caller's label (e.g. 'features[0]') instead of always
    # overwriting it with 'tuple'.
    print_list(obj, name=name or 'tuple')
def print_dict(obj, name=None):
    """Print a dict's label and sorted keys, then each value recursively.

    Each value is printed with ``print_obj``, using its key as the label.
    """
    label = name if name else 'dict'
    print("Dict: '%s'" % label, flush=True)
    sorted_keys = sorted(obj.keys())
    print("keys: ", sorted_keys, flush=True)
    for key in sorted_keys:
        print_obj(obj[key], name=key)
def print_np(obj, name=None):
    """Print the shape of a numpy array, labelled ``name`` (or 'ndarray')."""
    shape_text = str(obj.shape)
    print('ndarray: %s.shape = %s' % (name or 'ndarray', shape_text), flush=True)
def main():
    """Entry point: parse CLI flags, load params, then debug or train.

    ``--model_dir`` and ``--device`` override entries in
    ``params['training']['runconfig_params']``. That is the dict ``train``
    unpacks into ``tf.estimator.RunConfig``, so the overrides take effect.
    """
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    args = parse_args()
    params = get_params(args.params)
    if args.model_dir or args.device:
        # Write into the key train() actually reads ('runconfig_params');
        # the previous 'run_config' key was never consumed, so these flags
        # had no effect.
        runconfig_params = params['training'].setdefault('runconfig_params', {})
        if args.model_dir:
            runconfig_params['model_dir'] = args.model_dir
        if args.device:
            # RunConfig invokes device_fn per op; place every op on the
            # requested device.
            runconfig_params['device_fn'] = lambda op: args.device
    if args.debug:
        run_debug(params)
    else:
        train(params)
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment