Skip to content

Instantly share code, notes, and snippets.

@interogativ
Created April 10, 2020 16:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save interogativ/5a40de51796901276273d651cb17be5a to your computer and use it in GitHub Desktop.
Save interogativ/5a40de51796901276273d651cb17be5a to your computer and use it in GitHub Desktop.
Architecture Runner as described in FastAI post
# Notebook setup. The display options below widen pandas/numpy/torch output, whose defaults assume a narrow screen.
from fastai2.vision.all import *
# Reproducibility: disable cuDNN's benchmark autotuner and request deterministic
# cuDNN algorithms, trading some speed for repeatable runs.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# Wider console output: show up to 999 DataFrame columns, and wrap numpy/torch
# array printing at 200 characters.
pd.options.display.max_columns = 999
np.set_printoptions(linewidth=200)
torch.set_printoptions(linewidth=200)
## A copy of Jeremy's context manager for recovering from CUDA out-of-memory crashes
class gpu_mem_restore_ctx():
    """Context manager to reclaim GPU RAM if CUDA ran out of memory or execution was interrupted.

    If the ``with`` body raises, the local variables of the frames recorded in
    the traceback are cleared, so the objects they reference can be garbage
    collected. The original exception is then re-raised. Nothing happens when
    the body completes normally.
    """

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_val is None:
            # No exception: nothing to clean up. The return value is irrelevant
            # on this path.
            return True
        # Clear the locals of the frames in the traceback. traceback.clear_frames
        # skips frames it cannot clear, e.g. the one still running this `with`.
        traceback.clear_frames(exc_tb)
        # Re-raise the original exception object instead of rebuilding it with
        # exc_type(exc_val). Rebuilding wraps the old exception as the new one's
        # only argument (KeyError('a') -> KeyError(KeyError('a'))), and it fails
        # with TypeError for exception types whose constructor needs different
        # arguments (e.g. UnicodeDecodeError).
        raise exc_val.with_traceback(exc_tb) from None
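##
## Specify the architectures to compare as an OrderedDict that maps
## name -> (architecture, initial batch size, epochs), for example:
#architectures = OrderedDict()
#architectures['resnet34'] = (resnet34, 64, 5)
#architectures['resnet50'] = (resnet50, 64, 5)
##architectures['squeezenet1_0'] = (squeezenet1_0, 64, 5)
#architectures['squeezenet1_1'] = (squeezenet1_1, 64, 5)
#architectures['vgg16_bn'] = (vgg16_bn, 32, 1)
# Put your trainer into a function called doTrain with the signature
#   doTrain(name, arch, bs, epochs)      e.g.  def doTrain(name, arch, bs, epochs=5): ...
# It should return a sequence of stats, in which stats[1] is reported as the loss
# and stats[2] as the error rate (accuracy = 1 - stats[2]). It should return a
# falsy value if training failed (e.g. ran out of memory); the run is then
# retried with half the batch size.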
def _train_with_backoff(name, arch, bs, epochs, doTrain, min_bs):
    "Run doTrain for one architecture, halving bs after each failed run; return its result record."
    while True:
        stats = doTrain(name, arch, bs, epochs)
        if stats:
            # Build a new plain list so any iterable of stats works. `list += int`
            # raises TypeError, and this also never mutates doTrain's return value.
            # The batch size and epochs actually used are appended at the end.
            record = [*stats, bs, epochs]
            print(f'Arch: {name} Loss: {record[1]:.03f} Accuracy: {1.0-record[2]:.03f}')
            return record
        if bs <= min_bs or bs < 2:
            # Halving again would go below the minimum batch size. Without this
            # bound, starting sizes that are not min_bs * 2**k would shrink to 0
            # and retry forever.
            print(f'Arch: {name} failed due to running out of memory')
            # Failure record: loss unknown (nan) and error rate 1.0, i.e. accuracy 0,
            # so a failed architecture ranks last in the summary instead of first.
            return [float('nan'), float('nan'), 1.0, bs, epochs]
        gc.collect()
        newbs = bs // 2
        print(f'Arch: {name} failed at {bs} Changing Batch size to {newbs}')
        bs = newbs


def architectureRunner(architectures, doTrain, min_bs=16):
    """Train every architecture in turn, then print a summary ranked by accuracy.

    architectures: mapping name -> (arch, initial batch size, epochs). Any extra
        tuple elements are ignored.
    doTrain: callable doTrain(name, arch, bs, epochs) that returns a sequence of
        stats. stats[1] is reported as the loss and stats[2] as the error rate
        (accuracy = 1 - stats[2]). A falsy return marks a failed run (presumably
        CUDA out of memory), which is retried with half the batch size.
    min_bs: stop retrying an architecture once a run at a batch size <= min_bs fails.

    Returns an OrderedDict name -> list(stats) + [bs, epochs], where bs is the
    batch size of the last run. Architectures that never succeeded get
    [nan, nan, 1.0, bs, epochs].
    """
    results = OrderedDict()
    for name, spec in architectures.items():
        arch, bs, epochs = spec[0], spec[1], spec[2]
        results[name] = _train_with_backoff(name, arch, bs, epochs, doTrain, min_bs)
    # Show which was best: ascending error rate, i.e. highest accuracy first.
    print('Results:')
    for name, record in sorted(results.items(), key=lambda item: item[1][2]):
        # bs is always second-to-last, however many stats doTrain returned.
        print(f'Arch: {name} bs={record[-2]} Loss: {record[1]:.03f} Accuracy: {1.0-record[2]:.03f}')
    return results
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment