Created
April 10, 2020 16:14
-
-
Save interogativ/5a40de51796901276273d651cb17be5a to your computer and use it in GitHub Desktop.
Architecture Runner as described in FastAI post
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Numpy, pandas and torch assume a narrow screen by default - the print options set below fix that
from fastai2.vision.all import * | |
# Widen printed output: up to 999 DataFrame columns, 200-character lines
# for numpy arrays and torch tensors.
pd.set_option('display.max_columns', 999)
np.set_printoptions(linewidth=200)
torch.set_printoptions(linewidth=200)

# cuDNN flags: request deterministic algorithms and turn off benchmark mode
# (presumably so runs of different architectures are repeatable - see the
# PyTorch reproducibility notes).
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
## This is a copy of Jeremy's function to recover from CUDA out of memory crashes | |
class gpu_mem_restore_ctx():
    """Context manager that frees traceback frames when its body fails.

    When the body raises (for example a CUDA out-of-memory error or a
    keyboard interrupt), the local variables of every finished frame in the
    traceback are cleared so the objects they reference can be
    garbage-collected, then the *same* exception is re-raised with exception
    chaining suppressed.
    """

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Local import keeps the class self-contained; the module's star
        # import is not guaranteed to export ``traceback``.
        import traceback

        # Compare against None: an exception instance may be falsy, and a
        # truthiness test would silently swallow it.
        if exc_val is None:
            return True

        # Drop the locals held by the traceback's frames; frames that are
        # still executing are skipped by ``clear_frames`` itself.
        traceback.clear_frames(exc_tb)

        # Re-raise the original instance. Rebuilding it via
        # ``exc_type(exc_val)`` wraps the message and raises TypeError for
        # exception types whose constructor needs several arguments.
        raise exc_val.with_traceback(exc_tb) from None
## | |
## specify architectures as follows | |
#def doTrain(name, arch, bs, epochs=5): | |
#architectures = OrderedDict() | |
#architectures['resnet34'] = (resnet34, 64, 5) | |
#architectures['resnet50'] = (resnet50, 64, 5)
##architectures['squeezenet1_0'] = (squeezenet1_0, 64, 5) | |
#architectures['squeezenet1_1'] = (squeezenet1_1, 64, 5) | |
#architectures['vgg16_bn'] = (vgg16_bn, 32, 1) | |
# put your trainer into a function called doTrain that has the form | |
# doTrain(name, arch, bs, epochs) | |
def architectureRunner(architectures, doTrain, min_bs=16):
    """Train each architecture, halving the batch size whenever a run fails.

    Args:
        architectures: Mapping of name -> ``(arch, bs, epochs)``, processed in
            the mapping's own iteration order.
        doTrain: Callable invoked as ``doTrain(name, arch, bs, epochs)``. A
            truthy return value is taken as the statistics sequence of a
            successful run; index 1 is printed as the loss and index 2 is
            used as the error rate (accuracy is shown as ``1.0 - stats[2]``).
            A falsy return value means the run failed (reported here as
            running out of memory) and is retried with half the batch size.
        min_bs: Smallest batch size a retry may use. Once halving would drop
            below it, the architecture is recorded as failed. Defaults to 16.

    Returns:
        OrderedDict of name -> ``[*stats, bs, epochs]`` for successful runs,
        or the failure placeholder ``[1.0, 1.0, 1.0, 0.0]`` (loss 1, error
        rate 1, batch size 0) for architectures that never trained.
    """
    # Local imports keep the function usable even if the module's star
    # import does not export these names.
    import gc
    from collections import OrderedDict

    # Never retry with a batch size below 1, whatever min_bs is.
    smallest_retry = max(min_bs, 1)
    results = OrderedDict()

    for name in architectures:
        arch, bs, epochs = architectures[name][:3]
        while True:
            stats = doTrain(name, arch, bs, epochs)
            if stats:
                # Build a new list so any sequence type (list, tuple, L, ...)
                # gets the batch size and epoch count appended.
                stats = [*stats, bs, epochs]
                results[name] = stats
                print(f'Arch: {name} Loss: {stats[1]:.03f} Accuracy: {1.0-stats[2]:.03f}')
                break

            halved = bs // 2
            if halved < smallest_retry:
                # Error rate 1.0 makes a failed architecture sort last and
                # show accuracy 0 instead of looking like a perfect run.
                results[name] = [1.0, 1.0, 1.0, 0.0]
                print(f'Arch: {name} failed due to running out of memory')
                break

            gc.collect()
            print(f'Arch: {name} failed at {bs} Changing Batch size to {halved}')
            bs = halved

    # Show which was best: ascending error rate, i.e. best accuracy first.
    print('Results:')
    for name, stats in sorted(results.items(), key=lambda item: item[1][2]):
        loss = stats[1]
        accuracy = 1.0 - stats[2]
        used_bs = stats[3]
        print(f'Arch: {name} bs={used_bs} Loss: {loss:.03f} Accuracy: {accuracy:.03f}')
    return results
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment