@morganmcg1 · Created September 14, 2020
import torch
import torch.nn as nn

class BatchNormFP32(nn.BatchNorm2d):
    "BatchNorm2d that always computes in float32, even when the rest of the model runs in half precision"
    def forward(self, x):
        return super().forward(x.float())  # CAST BatchNorm input to float32

# SWAP OUT REGULAR BN FOR BatchNormFP32 IN YOUR MODEL
def swap_batch_norm(model, layer_type_old, layer_type_new, copy_data=True):
    # TODO : make sure device is correct
    for name, module in reversed(list(model._modules.items())):
        if len(list(module.children())) > 0:
            # recurse into child modules
            model._modules[name] = swap_batch_norm(module, layer_type_old, layer_type_new, copy_data)
        if type(module) == layer_type_old:
            layer_old = module
            layer_new = layer_type_new(layer_old.num_features, eps=layer_old.eps,
                                       momentum=layer_old.momentum, affine=layer_old.affine,
                                       track_running_stats=layer_old.track_running_stats).cuda()
            if copy_data and layer_old.affine:
                # COPY WEIGHTS AND BIASES IN CASE IT'S PRETRAINED OR WE'VE DONE
                # SOME FANCY INITIALISATION (cast to float32 in case the source model was half precision)
                layer_new.weight.data = layer_old.weight.data.float()
                layer_new.bias.data = layer_old.bias.data.float()
            if copy_data and layer_old.track_running_stats:
                # carry over the running statistics as well
                layer_new.running_mean = layer_old.running_mean.float()
                layer_new.running_var = layer_old.running_var.float()
            model._modules[name] = layer_new
    return model

model = swap_batch_norm(model, nn.BatchNorm2d, BatchNormFP32)
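
A minimal usage sketch (hypothetical: it assumes a CUDA device, a torchvision ResNet-18 as the model, and PyTorch's torch.cuda.amp autocast for the half-precision part; none of these come from the gist itself):

import torch
import torch.nn as nn
from torchvision.models import resnet18  # any model containing nn.BatchNorm2d layers

model = resnet18().cuda()
model = swap_batch_norm(model, nn.BatchNorm2d, BatchNormFP32)

# check that no plain BatchNorm2d layers are left after the swap
assert not any(type(m) == nn.BatchNorm2d for m in model.modules())

x = torch.randn(2, 3, 224, 224, device='cuda')
with torch.cuda.amp.autocast():
    out = model(x)  # convs run in float16 under autocast; BatchNorm sees float32 inputs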