Skip to content

Instantly share code, notes, and snippets.

@dineshj1
Created August 30, 2016 22:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save dineshj1/839e08576d441944fd6f36ca6896453b to your computer and use it in GitHub Desktop.
Training on Pycaffe
import argparse
import time

start_time = time.time()

################## Argument Parsing #####################################
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--solver', default='', type=str)  # if empty, solver is created, else read
parser.add_argument('-res', '--resume_from', default='', type=str)  # if not empty, resumes training from given file
parser.add_argument('-ft', '--finetune_from', default='', type=str)
#parser.add_argument('-r','--rng_seed', default=242351, type=int); # not implemented
# NOTE(review): type=bool is an argparse pitfall -- ANY non-empty string
# (including "False") parses as True; only the bare defaults behave as
# expected.  Kept as-is for command-line compatibility; consider
# action='store_true' if no caller passes explicit values.
parser.add_argument('-d', '--debug_mode', default=False, type=bool)
parser.add_argument('-p', '--prefix', default='../clust_runs/', type=str)
parser.add_argument('--showfigs', default=False, type=bool)
# BUGFIX: was type=bool; log_interval participates in integer arithmetic
# (round-to-display snapping and "it % args.log_interval"), so it must be int.
parser.add_argument('--log_interval', default=2000, type=int)
# TODO not implemented (pending net read/write)
#parser.add_argument('-t','--trn_data', default='./panocon_cls_trn.h5', type=str);
#parser.add_argument('-v','--val_data', default='./panocon_cls_val.h5', type=str);
#parser.add_argument('-v','--test_data', default='./panocon_cls_val.h5', type=str);
#parser.add_argument('-s','--split_partition', default=1, type=int);
#parser.add_argument('-lw','--loss_weight', default=1, type=float);
# Solver parameters taken over within Pycaffe
parser.add_argument('--solver_max_iter', default=10000, type=int)
parser.add_argument('--solver_display', default=20, type=int)
parser.add_argument('--solver_test_iter', default=20, type=int)
parser.add_argument('--solver_test_interval', default=200, type=int)
parser.add_argument('--sys_cmd', default='', type=str)  # can be used to rsync etc.
# TODO Solver functions not implemented completely (pending solver read/write)
parser.add_argument('--solver_snapshot', default=1000, type=int)
parser.add_argument('--solver_snapshot_prefix', default='../clust_runs/caffe_snapshots/snap', type=str)
parser.add_argument('--solver_mode', default='GPU', type=str)
parser.add_argument('--solver_net', default='"../SUN360/clsnet_32_net.prototxt"', type=str)
parser.add_argument('--solver_base_lr', default=0.0001, type=float)
parser.add_argument('--solver_momentum', default=0.9, type=float)
parser.add_argument('--solver_weight_decay', default=5e-4, type=float)
parser.add_argument('--solver_lr_policy', default='"fixed"', type=str)
parser.add_argument('--caffe_pythonpath', default='/vision/vision_users/dineshj/caffe_vis/python/', type=str)
# early termination
parser.add_argument('-time', '--max_time', default=float('inf'), type=float)  # in minutes
parser.add_argument('-ET_w', '--saturation_wait', default=0, type=int)
parser.add_argument('-ET_tar', '--target_output', default='', type=str)
parser.add_argument('-ET_drn', '--target_drn', default='h', type=str)  # indicating that higher or lower is better
parser.add_argument('-ET_ov', '--overfit_margin', default=0, type=float)  # indicating how close to target
parser.add_argument('-ET_perf', '--target_perfect', default=1.0, type=float)  # indicating how close to target
args = parser.parse_args()
print(args)
#np.random.seed(args.rng_seed); #doesn't affect in any way at the moment (Caffe uses a different random seed)
#########################################################################
print "Importing necessary libraries"
import sys
import matplotlib as mpl
# Select a non-interactive backend BEFORE pyplot is imported, so figures can
# still be rendered and saved on a display-less (cluster) machine.
if not args.showfigs:
    mpl.use('Agg');
import matplotlib.pyplot as plt
#caffe_pythonpath='/vision/vision_users/dineshj/caffe_vis/python/';
# Make the user-specified pycaffe build importable.
sys.path.insert(0, args.caffe_pythonpath)
import caffe
#import lmdb
#from pylab import *
#import pylab
from matplotlib.figure import Figure
import numpy as np
import cPickle as pk
import scipy.io as sio
import pdb
import os
print("--- Runtime: %.2f secs ---" % (time.time()-start_time))
def update_logs(all_op_names, train_op_names, test_op_names, train_ops, test_ops, args, best_test_iter, best_test_score, target_op):
    """Plot train/test curves per output, save figures, and dump records.

    For every output, a 2-row figure (train on top, test below) is saved as
    both <prefix>_<name>.png and a pickled <prefix>_<name>.pkfig; all score
    arrays plus the best-so-far bookkeeping go into <prefix>.mat.  Optionally
    runs a user-supplied shell command (e.g. rsync) afterwards.

    Args:
        all_op_names: union of train and test output names (row order of the
            score arrays).
        train_op_names / test_op_names: outputs of the train / test net.
        train_ops / test_ops: (num_outputs x num_logged_steps) score arrays,
            logged every solver_display / solver_test_interval iterations.
        args: parsed command-line namespace (prefix, solver_display, ...).
        best_test_iter / best_test_score: best-so-far test iteration / score.
        target_op: row index (into all_op_names) of the early-term. target.
    """
    plt.close('all')
    print("Creating and saving plots")
    num_all_outputs = len(all_op_names)
    for opno in range(num_all_outputs):
        fig, ax_array = plt.subplots(nrows=2, sharex=True)
        if all_op_names[opno] in train_op_names:
            curr_op = train_ops[opno]
            ax_array[0].plot((np.arange(len(curr_op))) * args.solver_display, curr_op)
            ax_array[0].set_title('train %s' % all_op_names[opno])
            ax_array[0].axvline(x=best_test_iter, color='r')
        if all_op_names[opno] in test_op_names:
            curr_op = test_ops[opno]
            ax_array[1].plot((np.arange(len(curr_op))) * args.solver_test_interval, curr_op)
            ax_array[1].set_title('test %s' % all_op_names[opno])
            ax_array[1].axvline(x=best_test_iter, color='r')
            # Annotate the best test score next to its marker on the target
            # output's test subplot, with range-adaptive precision.
            if opno == target_op:
                y_range = ax_array[1].get_ylim()
                y_ht = y_range[1] - y_range[0]
                scale = np.round(np.log10(y_ht)) <= 1  # True (1) for small y-ranges
                if scale:
                    prec_str = "%%%df" % (scale + 3)
                    print(prec_str)
                else:
                    prec_str = "%d"
                ax_array[1].text(best_test_iter, best_test_score + y_ht * 0.05,
                                 (prec_str % best_test_score).replace("-0", "-").lstrip("0"),
                                 color='r')
        if args.showfigs:
            print("Trying to show plots")
            try:
                plt.ion()
                plt.show()
            except Exception as e:
                print(e.__doc__)
                print(e.message)
                print("Skipping showing figure. Saving directly.")
        fig_name_root = "%s_%s" % (args.prefix, all_op_names[opno])
        print("Storing fig to %s(.png/.pkfig)" % fig_name_root)
        plt.savefig(fig_name_root + '.png')
        # BUGFIX: was pk.dump(fig, file(..., 'w')) -- Python-2-only file()
        # builtin, text-mode handle for (binary) pickle data, and a leaked
        # file descriptor.  Use a context-managed binary open() instead.
        with open(fig_name_root + '.pkfig', 'wb') as fig_file:
            pk.dump(fig, fig_file)
    matfilename = "%s.mat" % args.prefix
    print("Saving records to %s" % matfilename)
    sys.stdout.flush()
    sio.savemat(matfilename,
                {
                    'train_ops': train_ops,
                    'test_ops': test_ops,
                    'train_op_names': train_op_names,
                    'test_op_names': test_op_names,
                    'all_op_names': all_op_names,
                    'best_test_score': best_test_score,
                    'best_test_iter': best_test_iter
                })
    if args.sys_cmd:
        print("Running sys cmd: %s" % args.sys_cmd)
        # Best-effort: a failing command must not kill the training run.
        try:
            os.system(args.sys_cmd)
        except Exception as e:
            print(e.__doc__)
            print(e.message)
            print("Cmd did not work.")
#########################################################################
#print "Setting up network"
#from caffe import layers as L
#from caffe import params as P
#print "Beginning net definition"
#def net(lmdbname, batch_size):
# n = caffe.NetSpec()
# n.data, n.label = L.Data(
# batch_size=batch_size,
# backend=P.Data.LMDB,
# source=lmdbname,
# transform_param=dict(
# #mirror=True,
# #crop_size=227,
# #mean_file='/scratch/vision/dineshj/caffe2/data/ilsvrc12/imagenet_mean.binaryproto'
# ),
# ntop=2,
# )
# n.data
# return n.to_proto()
#with open('auto_train.prototxt', 'w') as f:
# f.write(str(mini_net(lmdbname, 64)))
#print("--- Runtime: %.2f secs ---" % (time.time()-start_time))
## Write a solver prototxt, unless an existing one was given via --solver.
solver_file = args.solver
if not solver_file:
    print("Setting up solver")
    solver_file = "%s_solver.prototxt" % args.prefix
    print("Writing a solver")
    solver_dict = {}
    # Caffe's prototxt parser requires string values to be double-quoted;
    # add the quotes here if the caller did not.
    if not args.solver_net[0] == '"':
        args.solver_net = '"' + args.solver_net + '"'
    if not args.solver_lr_policy[0] == '"':
        args.solver_lr_policy = '"' + args.solver_lr_policy + '"'
    # BUGFIX: snapshot_prefix was only added to solver_dict when it needed
    # quoting; an already-quoted prefix was silently dropped from the solver
    # file.  (args.solver_snapshot_prefix itself is deliberately left
    # unmutated -- it is reused later, unquoted, for solver.net.save.)
    if not args.solver_snapshot_prefix[0] == '"':
        solver_dict['snapshot_prefix'] = str('"' + args.solver_snapshot_prefix + '"')
    else:
        solver_dict['snapshot_prefix'] = str(args.solver_snapshot_prefix)
    solver_dict['net'] = args.solver_net
    # Testing and display are driven from this script, not by the solver:
    # disable the solver's own testing (interval beyond max_iter) and display.
    solver_dict['test_iter'] = str(0)
    solver_dict['test_interval'] = str(int(args.solver_max_iter) * 2)
    solver_dict['base_lr'] = str(args.solver_base_lr)
    solver_dict['momentum'] = str(args.solver_momentum)
    solver_dict['weight_decay'] = str(args.solver_weight_decay)
    solver_dict['lr_policy'] = str(args.solver_lr_policy)
    solver_dict['display'] = str(0)
    solver_dict['max_iter'] = str(args.solver_max_iter)
    solver_dict['snapshot'] = str(args.solver_snapshot)
    solver_dict['solver_mode'] = args.solver_mode
    # BUGFIX: file() is a Python-2-only builtin; open() is the portable
    # spelling.  Key order in the prototxt is arbitrary, which caffe accepts.
    with open(solver_file, 'w') as f:
        for key in solver_dict:
            f.write(key + ':' + solver_dict[key] + '\n')
print("--- Runtime: %.2f secs ---" % (time.time() - start_time))
print "Loading solver"
#caffe.set_mode_cpu();
solver=caffe.SGDSolver(solver_file);
# Resume restores the full solver state (weights + momentum + iteration) from
# a .solverstate; finetune only copies matching layer weights (.caffemodel).
if args.resume_from:
    print "Resuming from %s" %(args.resume_from)
    solver.restore(args.resume_from);
elif args.finetune_from:
    print "Finetuning %s" %(args.finetune_from)
    solver.net.copy_from(args.finetune_from);
print("--- Runtime: %.2f secs ---" % (time.time()-start_time))
# ---- Discover net outputs and initialize logging/early-termination state ----
test_batch_sz=solver.test_nets[0].blobs.items()[0][1].shape[0];  # batch size of the first blob (presumably the data blob -- TODO confirm)
it=0;
# get outputs automatically from solver.net.outputs
train_op_names=solver.net.outputs;
#num_outputs=len(train_op_names);
# get test outputs automatically from solver.test_nets[0].outputs
test_op_names=solver.test_nets[0].outputs;
#num_test_outputs=len(test_op_names);
# NOTE(review): the set() union makes the row order of all_op_names arbitrary
# (it may differ across runs); indices computed below are consistent within
# a single run, which is all the logging needs.
all_op_names= list(set(train_op_names) | set(test_op_names));
num_all_outputs=len(all_op_names);
# One row per output; one column per display step (train) / test interval.
train_ops=np.zeros((num_all_outputs, max(args.solver_max_iter/args.solver_display, 1)));
test_ops=np.zeros((num_all_outputs, max(args.solver_max_iter/args.solver_test_interval,1)));
# Automatically determine the target variable to determine early termination based on, etc.
if args.target_output:
    try:
        target_op=all_op_names.index(args.target_output);
    except Exception as e:
        print e.__doc__
        print e.message
        print "Could not find output %s. Setting to empty." % (args.target_output);
        args.target_output='';
if not args.target_output:
    # TODO include a regex search for outputs starting with "target_"
    # Fall back to a conventional accuracy output name (higher is better).
    if 'accuracy' in test_op_names:
        target_op=all_op_names.index('accuracy');
        args.target_drn='h';
    elif 'cls_accuracy' in test_op_names:
        target_op=all_op_names.index('cls_accuracy');
        args.target_drn='h';
    elif 'acc' in test_op_names:
        target_op=all_op_names.index('acc');
        args.target_drn='h';
    elif 'cls_acc' in test_op_names:
        target_op=all_op_names.index('cls_acc');
        args.target_drn='h';
    else:
        target_op=0; # setting randomly to the first output
print "Setting target output to #%d:%s (dir: %s)" % (target_op, all_op_names[target_op], args.target_drn);
# Early termination fires only when both the requested criteria (overfit on
# train score, saturation on test score) have cleared; see the main loop.
ET_on=False
if args.saturation_wait>0 or args.overfit_margin>0:
    ET_on=True;
    print "EARLY TERMINATION ON";
# Snap both intervals to multiples of solver_display, since the loop below
# advances `it` in steps of solver_display.  NOTE(review): np.round turns
# these into numpy floats for the rest of the script.
args.solver_test_interval=np.round(args.solver_test_interval/args.solver_display)*args.solver_display;
args.log_interval=np.round(args.log_interval/args.solver_display)*args.solver_display;
best_test_score=np.NaN;  # NaN compares False, so the first test score always becomes the best
best_test_iter=0;
timeout_flag=False;
termination_flag=False;
# A criterion that was not requested counts as already cleared.
if args.overfit_margin>0:
    overfit_clear=False;
else:
    overfit_clear=True;
if args.saturation_wait>0:
    saturate_clear=False;
else:
    saturate_clear=True;
overfit_window=5;  # running-average window (in display steps) for the train target score
# Prime both nets so their blobs hold valid outputs before the first log pass.
solver.net.forward();
solver.test_nets[0].forward();
if args.debug_mode:
    pdb.set_trace()
print "Beginning iterations"
# Main loop: each pass logs the current train outputs, periodically runs the
# test net and checks the early-termination criteria, then advances the
# solver by solver_display iterations.
while it < args.solver_max_iter and not termination_flag:
    sys.stdout.flush()
    # Record the latest training outputs (blobs hold the values from the most
    # recent forward pass -- either the priming pass or solver.step below).
    for opno in range(num_all_outputs):
        if all_op_names[opno] in train_op_names:
            train_ops[opno, it/args.solver_display]=solver.net.blobs[all_op_names[opno]].data;
    print 'Py-iteration', it, 'training outputs ...'
    print train_ops[:, it/args.solver_display]
    # Overfit criterion: running average of the train target score within
    # overfit_margin of target_perfect.
    if not overfit_clear:
        if it/args.solver_display>=overfit_window:
            train_score_running_avg = np.mean(train_ops[target_op, it/args.solver_display:it/args.solver_display-overfit_window:-1]);
            if np.abs(train_score_running_avg-args.target_perfect) < args.overfit_margin:
                overfit_clear=True;
                print "Running avg: %f" % train_score_running_avg;
                print "Overfit target achieved";
    runtime = (time.time()-start_time)/60;  # minutes
    if runtime>args.max_time:
        print "Ran out of time. Finishing up early."
        timeout_flag=True;
    # ---- Periodic testing ----
    if it % args.solver_test_interval==0:
        print '====================================='
        print("--- runtime: %.2f / %.2f mins ---" % (runtime, args.max_time));
        print 'Py-iteration', it, 'testing outputs ...'
        correct=0;   # NOTE(review): unused legacy accumulators; see the
        loss_sum=0;  # commented-out accuracy computation below
        test_op_sum=np.zeros(num_all_outputs);
        # Average each test output over solver_test_iter mini-batches.
        for test_it in range(args.solver_test_iter):
            solver.test_nets[0].forward()
            for opno in range(num_all_outputs):
                if all_op_names[opno] in test_op_names:
                    test_op_sum[opno]+=solver.test_nets[0].blobs[all_op_names[opno]].data;
            #correct+=sum(solver.test_nets[0].blobs['cls_ip2'].data.argmax(1) == solver.test_nets[0].blobs['cls_label'].data);
        # NOTE(review): solver_test_interval is a numpy float after the
        # snap-to-display rounding, so this is float indexing -- accepted
        # (with a deprecation) by 2016-era numpy; verify on newer versions.
        test_ops[:, it/args.solver_test_interval]=test_op_sum[:]/args.solver_test_iter;
        #test_acc[it/args.solver_test_interval] = correct/(args.solver_test_iter*test_batch_sz);
        print test_ops[:, it/args.solver_test_interval]
        # implement early termination
        if args.debug_mode:
            pdb.set_trace()
        # Track the best test score of the target output; since NaN
        # comparisons are False, the first measured score always wins.
        if args.target_drn=='h':
            if not best_test_score > test_ops[target_op, it/args.solver_test_interval]: # i.e. current score is highest
                best_test_score=test_ops[target_op, it/args.solver_test_interval];
                best_test_iter=it;
                print "%s (op #%d/%d) improved!" % (all_op_names[target_op], target_op+1, len(all_op_names));
                print(args.solver_snapshot_prefix + '_bestweights')
                solver.net.save(args.solver_snapshot_prefix + '_bestweights'); # update saved best model
            else:
                print "best %s (op #%d/%d) so far: %f (at iter %d)" %(all_op_names[target_op], target_op+1, len(all_op_names), best_test_score, best_test_iter);
        elif args.target_drn=='l':
            if not best_test_score < test_ops[target_op, it/args.solver_test_interval]: # i.e. current score is lowest
                best_test_score=test_ops[target_op, it/args.solver_test_interval];
                best_test_iter=it;
                print "%s (op #%d/%d) improved!" % (all_op_names[target_op], target_op+1, len(all_op_names));
                solver.net.save(args.solver_snapshot_prefix + '_bestweights'); # update saved best model
            else:
                print "best %s (op #%d/%d) so far: %f (at iter %d)" %(all_op_names[target_op], target_op+1, len(all_op_names), best_test_score, best_test_iter);
        else:
            raise NameError('Unknown target direction (target_drn) %s' % args.target_drn);
        # Saturation criterion: no test improvement for saturation_wait iters
        # (only checked once the overfit criterion has cleared).
        if overfit_clear and not saturate_clear:
            if it - best_test_iter > args.saturation_wait: # time to quit!
                saturate_clear=True;
                print "Test performance saturation target cleared"
        print '====================================='
    # Periodically refresh the on-disk plots and .mat logs.
    if it % args.log_interval==0:
        update_logs(all_op_names, train_op_names, test_op_names, train_ops, test_ops, args, best_test_iter, best_test_score, target_op);
    termination_flag = timeout_flag or (ET_on and overfit_clear and saturate_clear);
    if not termination_flag:
        solver.step(args.solver_display);
        it=it+args.solver_display;
    else:
        print "Triggering early termination";
# On early termination, trim the unused tail of the log arrays and write one
# final set of logs.
if termination_flag:
    train_ops=train_ops[:,:it/args.solver_display+1];
    test_ops=test_ops[:,:it/args.solver_test_interval+1];
    update_logs(all_op_names, train_op_names, test_op_names, train_ops, test_ops, args, best_test_iter, best_test_score, target_op);
elapsed_s=time.time()-start_time;
print("--- %.2f secs (%.2f mins) ---" % (elapsed_s, elapsed_s/60))
sys.stdout.flush();
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment