Skip to content

Instantly share code, notes, and snippets.

@rreece
Created May 1, 2020 18:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rreece/0a6e2c418ea81263b09a3071eb2b1c63 to your computer and use it in GitHub Desktop.
Save rreece/0a6e2c418ea81263b09a3071eb2b1c63 to your computer and use it in GitHub Desktop.
import argparse
import numpy as np
import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2
from model import model_fn
from data import input_fn
from utils import (
DEFAULT_PARAMS_FILE,
get_params,
)
def parse_args(argv=None):
    """Parse the command-line arguments of this script.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse falls back to ``sys.argv[1:]``, which is what
            the original zero-argument call did.

    Returns:
        argparse.Namespace with the attributes ``params``, ``model_dir``,
        ``device`` and ``debug``.
    """
    # Parse command line args
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--params',
        default=DEFAULT_PARAMS_FILE,
        help='Path to .yaml file with model parameters',
    )
    parser.add_argument(
        '--model_dir',
        default=None,
        help='Model directory',
    )
    parser.add_argument(
        '--device',
        default=None,
        help='Force model to run on a specific device (e.g., --device /gpu:0)',
    )
    parser.add_argument(
        '--debug',
        # store_true already defaults to False when the flag is absent.
        action='store_true',
        help='Print a few input batches instead of training',
    )
    return parser.parse_args(argv)
def train(params):
    """
    Build a tf.estimator.Estimator around model_fn and run training.

    params is the full parameter dict: it is handed to the Estimator
    unchanged, and its 'training' section supplies the keyword arguments
    for RunConfig ('runconfig_params') and for Estimator.train
    ('train_params').
    """
    training_cfg = params['training']

    # Session config with the graph rewriter's memory optimization disabled.
    session_config = tf.compat.v1.ConfigProto()
    rewrite_options = session_config.graph_options.rewrite_options
    rewrite_options.memory_optimization = (
        rewriter_config_pb2.RewriterConfig.OFF
    )

    run_config = tf.estimator.RunConfig(
        **training_cfg['runconfig_params'],
        session_config=session_config,
    )
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        config=run_config,
        params=params,
    )
    estimator.train(input_fn, **training_cfg['train_params'])
def run_debug(params):
    """Print a summary of the first three elements yielded by input_fn.

    Calls ``input_fn(params)``, creates an initializable iterator over the
    result, and inside a session prints each fetched element with
    ``print_obj``.

    Args:
        params: Parameter dict forwarded unchanged to ``input_fn``
            (presumably returning a tf.data.Dataset -- confirm in data.py).
    """
    dataset = input_fn(params)
    # Use the tf.compat.v1 endpoints, like the rest of this file:
    # Dataset.make_initializable_iterator() and the bare tf.Session /
    # tf.tables_initializer names are deprecated or removed in newer
    # TensorFlow releases.
    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    next_element = iterator.get_next()
    with tf.compat.v1.Session() as sess:
        sess.run(iterator.initializer)
        sess.run(tf.compat.v1.tables_initializer())
        for batch_num in (1, 2, 3):
            print('DEBUG: batch %i' % batch_num, flush=True)
            print_obj(sess.run(next_element))
def print_obj(obj, name=None):
    """Print a short description of obj, dispatching on its type.

    Lists, tuples, dicts and numpy arrays are delegated to print_list,
    print_tuple, print_dict and print_np respectively; anything else is
    printed as its type followed by its value.

    Args:
        obj: Object to describe.
        name: Optional label passed through to the type-specific printer.
    """
    # A single if/elif chain: previously the list branch was a separate
    # `if`, so lists were summarized and then also dumped by the fallback.
    if isinstance(obj, list):
        print_list(obj, name)
    elif isinstance(obj, tuple):
        print_tuple(obj, name)
    elif isinstance(obj, dict):
        print_dict(obj, name)
    elif isinstance(obj, np.ndarray):
        print_np(obj, name)
    else:
        print(type(obj), flush=True)
        print(obj, flush=True)
def print_list(obj, name=None):
    """Print the length of a sequence and, if it is short, its elements.

    The length line is always printed. Elements are described through
    print_obj (labelled ``name[i]``) only when there are fewer than 10.
    A falsy name is replaced by 'list'.
    """
    label = name or 'list'
    size = len(obj)
    print('List: len({!s}) = {:d}'.format(label, size), flush=True)
    if size >= 10:
        # Long sequences are only summarized by their length.
        return
    for index, item in enumerate(obj):
        print_obj(item, name='{!s}[{:d}]'.format(label, index))
def print_tuple(obj, name=None):
    """Describe a tuple by reusing the list printer.

    Args:
        obj: Tuple to describe.
        name: Optional label; falls back to 'tuple' when empty or None.
            Previously the caller's label was discarded and 'tuple' was
            always used.
    """
    print_list(obj, name=name or 'tuple')
def print_dict(obj, name=None):
    """Print a dict's label and sorted keys, then describe each value.

    Values are described with print_obj in sorted-key order, each labelled
    with its key. A falsy name is replaced by 'dict'.
    """
    label = name if name else 'dict'
    print("Dict: '%s'" % label, flush=True)
    sorted_keys = sorted(obj.keys())
    print("keys: ", sorted_keys, flush=True)
    for key in sorted_keys:
        print_obj(obj[key], name=key)
def print_np(obj, name=None):
    """Print the shape of an array-like object on one line.

    Only ``obj.shape`` is read. A falsy name is replaced by 'ndarray'.
    """
    label = name if name else 'ndarray'
    message = 'ndarray: {!s}.shape = {!s}'.format(label, obj.shape)
    print(message, flush=True)
def main():
    """Load parameters, apply command-line overrides, then train or debug.

    ``--model_dir`` and ``--device`` are injected into the RunConfig keyword
    arguments; ``--debug`` switches from training to printing input batches.
    """
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    args = parse_args()
    params = get_params(args.params)

    # train() builds tf.estimator.RunConfig from
    # params['training']['runconfig_params'], so the overrides must be
    # written under that key to take effect (they were previously written
    # to 'run_config', which train() never reads).
    if args.model_dir:
        params['training']['runconfig_params']['model_dir'] = args.model_dir
    if args.device:
        # device_fn receives an op and returns the device string to use;
        # here every op is pinned to the requested device.
        params['training']['runconfig_params']['device_fn'] = (
            lambda op: args.device
        )

    if args.debug:
        run_debug(params)
    else:
        train(params)
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment