Created
March 7, 2019 14:02
-
-
Save bernhardschaefer/01905b0fe83615f79e2928a2a10b6f28 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Parameter names deleted from the checkpoint's 'model' state dict by
# trim_maskrcnn_benchmark_model: the box classifier / box regressor layers and
# the mask logits layer of the ROI heads.
# NOTE(review): presumably these are removed because their shapes depend on the
# number of classes, so the trimmed checkpoint can be fine-tuned on a dataset
# with a different class count — confirm against the training setup.
# NOTE(review): the 'module.' prefix looks like it comes from saving a
# (Distributed)DataParallel-wrapped model — verify against the actual checkpoint.
keys_to_remove = [
    'module.roi_heads.box.predictor.cls_score.weight',
    'module.roi_heads.box.predictor.cls_score.bias',
    'module.roi_heads.box.predictor.bbox_pred.weight',
    'module.roi_heads.box.predictor.bbox_pred.bias',
    'module.roi_heads.mask.predictor.mask_fcn_logits.weight',  # mask
    'module.roi_heads.mask.predictor.mask_fcn_logits.bias',  # mask
]
def trim_maskrcnn_benchmark_model(model_path: str, trimmed_model_path: str, keys=None):
    """Strip selected weights and training state from a checkpoint and re-save it.

    Loads the checkpoint at ``model_path`` onto the CPU. Deletes each name in
    ``keys`` from its ``'model'`` entry, then drops the ``'optimizer'``,
    ``'scheduler'`` and ``'iteration'`` entries if present. Writes the result
    to ``trimmed_model_path`` with ``torch.save``.

    Args:
        model_path: Path of the checkpoint file to read.
        trimmed_model_path: Path the trimmed checkpoint is written to.
        keys: Iterable of parameter names to delete from
            ``state_dict['model']``. Names not present are reported and
            skipped. Defaults to the module-level ``keys_to_remove``.

    Raises:
        KeyError: If the loaded checkpoint has no ``'model'`` entry.
    """
    # Local import: the file does not import torch at module level.
    import torch

    if keys is None:
        keys = keys_to_remove

    state_dict = torch.load(model_path, map_location="cpu")
    model = state_dict['model']
    for key in keys:
        if key in model:
            del model[key]
            print(f'key: {key} is removed')
        else:
            print(f'key: {key} is not present')

    print("Also deleting optimizer, scheduler, and iteration entries")
    # pop() rather than del: a checkpoint without training state (for example
    # one already produced by this function) would otherwise raise KeyError.
    for entry in ('optimizer', 'scheduler', 'iteration'):
        state_dict.pop(entry, None)

    torch.save(state_dict, trimmed_model_path)
    print(f'saved to: {trimmed_model_path}')
# Usage example:
# model_path = "../maskrcnn-benchmark/models/e2e_mask_rcnn_R_50_FPN_1x.pth"
# trimmed_model_path = "../maskrcnn-benchmark/models/e2e_mask_rcnn_R_50_FPN_1x_trimmed.pth"
# trim_maskrcnn_benchmark_model(model_path, trimmed_model_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment