@mkolod
Last active October 21, 2020 21:03
# NOTE: The network here is not meant to make any sense. It's just for measuring perf impact.
import torch
import torch.nn.functional as F
from time import time


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # A stack of 21 linear layers, purely to generate enough GPU work to measure.
        fcs = [torch.nn.Linear(10, 100)] + [torch.nn.Linear(100, 100) for _ in range(20)]
        self.fcs = torch.nn.Sequential(*fcs)

    def forward(self, x):
        return self.fcs(x)


class Combine(torch.nn.Module):
    def __init__(self, fork=False):
        super(Combine, self).__init__()
        self.fork = fork
        self.branch0 = Net()
        self.branch1 = Net()

    @torch.jit.export
    def forward_forked(self, x):
        # Launch both branches asynchronously and block on both futures.
        # Parallel execution only kicks in once the module is compiled
        # with torch.jit.script.
        fut0 = torch.jit.fork(self.branch0, x)
        fut1 = torch.jit.fork(self.branch1, x)
        return torch.jit.wait(fut0) + torch.jit.wait(fut1)

    @torch.jit.export
    def forward_reg(self, x):
        # Sequential baseline: run one branch after the other.
        return self.branch0(x) + self.branch1(x)

    def forward(self, x):
        return self.forward_forked(x) if self.fork else self.forward_reg(x)


if __name__ == '__main__':
    combine_reg = torch.jit.script(Combine(fork=False).cuda().eval())
    combine_forked = torch.jit.script(Combine(fork=True).cuda().eval())
    x = torch.randn(10, 10).cuda()
    # Warm up both variants so JIT compilation and CUDA setup
    # don't pollute the measurement.
    for _ in range(50):
        res0 = combine_reg(x)
        res1 = combine_forked(x)
    torch.cuda.synchronize()
    start_reg = time()
    res0 = combine_reg(x)
    torch.cuda.synchronize()
    end_reg = time()
    res1 = combine_forked(x)
    torch.cuda.synchronize()
    end_forked = time()
    print(f"exec time of non-forked: {(end_reg - start_reg) * 1000:.3f} ms")
    print(f"exec time of forked: {(end_forked - end_reg) * 1000:.3f} ms")
mkolod commented Oct 12, 2020

Tested on:

  • Intel(R) Xeon(R) Gold 6136 CPU @ 3.00GHz
  • NVIDIA Titan RTX, clocks locked at 1,620 MHz
  • PyTorch 1.6

Result:

python torch_fork.py
exec time of non-forked: 2.935 ms
exec time of forked: 2.004 ms
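
Host-side time() around a single iteration is sensitive to launch overhead and scheduler jitter; CUDA events are one way to cross-check the numbers. A minimal sketch reusing x and the two scripted modules from the script above (time_cuda and iters are illustrative names, not from the gist):

def time_cuda(fn, x, iters=100):
    # Record events on the current stream and average over many iterations.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # elapsed_time returns milliseconds

print(f"non-forked: {time_cuda(combine_reg, x):.3f} ms")
print(f"forked:     {time_cuda(combine_forked, x):.3f} ms")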

mkolod commented Oct 12, 2020

Before (left) and after forking (right) in nvvp

[screenshot: nvvp timeline comparison, 2020-10-12]
