@wookim3
Created October 19, 2020 14:28
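
import torch
import torch.distributed as dist

# API the hook plugs into (PyTorch 1.7): DistributedDataParallel._register_comm_hook.
# `state` is handed to `hook` as its first argument on every call; `hook` should
# return a torch.futures.Future holding the list of reduced gradient tensors.
#
#   def _register_comm_hook(self, state: object, hook: callable): ...


def fp16_compress_hook(process_group: object, bucket: dist._GradBucket):
    # Fall back to the default (global) group when no process group is given as state.
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = (
        process_group.size() if process_group is not None else dist.get_world_size()
    )

    # Compress: cast the bucket's flattened gradients to fp16 (half the bytes of fp32).
    compressed_tensor = bucket.get_tensors()[0].to(torch.float16)

    # Launch an asynchronous all-reduce (sum) and wrap its result in a Future.
    fut = dist.all_reduce(
        compressed_tensor,
        group=group_to_use,
        async_op=True
    ).get_future()

    def decompress(fut):
        # Decompress back to fp32, then divide the sum by world_size to average.
        return [fut.value()[0]
                .to(torch.float32)
                .div_(world_size)]

    return fut.then(decompress)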
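The hook's arithmetic can be tried without a GPU or a process group. Below is a minimal pure-Python simulation of its steps: compress to fp16, sum across workers, then average. It assumes each worker's bucket is a plain list of floats and that the all-reduce sums in half precision. `to_fp16` and `simulated_fp16_allreduce_mean` are hypothetical helpers, not PyTorch APIs. The rounding relies on the standard library's `struct` "e" format (IEEE 754 binary16).

```python
import struct
from typing import List


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    # struct's 'e' format packs binary16; unpacking returns it as a Python float.
    # Raises OverflowError for magnitudes beyond the fp16 range (max 65504).
    return struct.unpack("<e", struct.pack("<e", x))[0]


def simulated_fp16_allreduce_mean(worker_grads: List[List[float]]) -> List[float]:
    """Hypothetical stand-in for fp16_compress_hook on plain lists of floats."""
    world_size = len(worker_grads)
    # 1. Compress: every worker casts its gradient bucket to fp16.
    compressed = [[to_fp16(g) for g in grads] for grads in worker_grads]
    # 2. All-reduce (sum): add element-wise, rounding to fp16 after each addition
    #    (assumption: the reduction itself is carried out in half precision).
    summed = compressed[0][:]
    for grads in compressed[1:]:
        summed = [to_fp16(a + b) for a, b in zip(summed, grads)]
    # 3. Decompress and average: Python floats are already wider than fp16, so
    #    the cast back is implicit; divide by world_size as the hook does.
    return [s / world_size for s in summed]


print(simulated_fp16_allreduce_mean([[1.0, 2.0], [3.0, 4.0]]))
print(simulated_fp16_allreduce_mean([[0.1, 0.2], [0.3, 0.4]]))
```

Values that fp16 represents exactly, such as 1.0 through 4.0, come back unchanged. Values like 0.1 only land near the true mean, which shows the precision this scheme gives up in exchange for sending half as many bytes.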