@awni · Created August 6, 2017 18:08
Test PyTorch attention performance: batched matrix multiply (bmm) vs. broadcasting.
import time
import torch
import torch.nn as nn
from torch.autograd import Variable

def attend_bmm(eh, dhx):
    """Attention via batched matrix multiplies.

    eh:  encoder states, (batch, time, dim)
    dhx: decoder state,  (batch, dim)
    """
    dhx = dhx.unsqueeze(1)                                    # (batch, 1, dim)
    pax = torch.bmm(eh, dhx.transpose(1, 2)).squeeze(dim=2)   # (batch, time) scores
    ax = nn.functional.softmax(pax, dim=1)                    # attention weights
    sx = ax.unsqueeze(2)                                      # (batch, time, 1)
    sx = torch.bmm(eh.transpose(1, 2), sx)                    # (batch, dim, 1)
    return sx.squeeze(dim=2), ax

def attend_bx(eh, dhx):
    """Attention via broadcasting and sums; same result as attend_bmm."""
    pax = torch.sum(eh * dhx.unsqueeze(dim=1), dim=2)   # (batch, time) scores
    ax = nn.functional.softmax(pax, dim=1)              # attention weights
    sx = torch.sum(eh * ax.unsqueeze(dim=2), dim=1)     # (batch, dim) context
    return sx, ax
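
# Added for illustration (not in the original gist): the same attention
# expressed with torch.einsum, available in recent PyTorch versions.
# Subscripts: b = batch, t = time, d = dim.
def attend_einsum(eh, dhx):
    pax = torch.einsum('btd,bd->bt', eh, dhx)   # (batch, time) scores
    ax = nn.functional.softmax(pax, dim=1)      # attention weights
    sx = torch.einsum('btd,bt->bd', eh, ax)     # (batch, dim) context
    return sx, ax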

def perf(eh, dhx):
    n_runs = 200
    # CUDA kernels launch asynchronously; synchronize before reading the
    # clock so the GPU timings are accurate.
    def sync():
        if eh.is_cuda:
            torch.cuda.synchronize()
    sync()
    start = time.time()
    for _ in range(n_runs):
        attend_bmm(eh, dhx)
    sync()
    tot_bmm = time.time() - start
    start = time.time()
    for _ in range(n_runs):
        attend_bx(eh, dhx)
    sync()
    tot_bx = time.time() - start
    print("BMM {:.3f} (s) -- BX {:.3f} (s)".format(tot_bmm, tot_bx))

if __name__ == "__main__":
    eh = Variable(torch.randn(16, 200, 512))   # encoder states (batch, time, dim)
    dhx = Variable(torch.randn(16, 512))       # decoder state  (batch, dim)
    # warm-up
    s1, a1 = attend_bmm(eh, dhx)
    s2, a2 = attend_bx(eh, dhx)
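    # Sanity check (not in the original gist): both implementations should
    # produce the same context vectors and attention weights up to
    # floating-point error. Assumes a PyTorch version with torch.allclose.
    assert torch.allclose(s1, s2, atol=1e-5), "context vectors differ"
    assert torch.allclose(a1, a2, atol=1e-5), "attention weights differ"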
print("CPU TIMES")
perf(eh, dhx)
print("GPU TIMES")
perf(eh.cuda(), dhx.cuda())