Skip to content

Instantly share code, notes, and snippets.

@Cadene
Created August 8, 2019 07:20
Show Gist options
  • Save Cadene/a00eba2db81f44b4492dc8941f6abe45 to your computer and use it in GitHub Desktop.
Save Cadene/a00eba2db81f44b4492dc8941f6abe45 to your computer and use it in GitHub Desktop.
Bug indexing
import time
import torch
import itertools
def make_pairs_ids(nregion, bsize):
pairs_ids = []
for batch_id in range(bsize):
pairs_id = torch.LongTensor([
(batch_id,i,j) for i,j in \
itertools.product(range(nregion),repeat=2)])
pairs_ids.append(pairs_id)
out = torch.cat(pairs_ids).contiguous()
return out
if __name__ == '__main__':
niter=10
bsize=32
nregion=36
dimh=2048
module = torch.nn.Linear(dimh, dimh)
mm = torch.randn(bsize, nregion, dimh)
pairs_ids = make_pairs_ids(nregion, bsize)
module.cuda()
mm = mm.cuda()
pairs_ids = pairs_ids.cuda()
mm = torch.autograd.Variable(mm, requires_grad=True)
t = time.time()
torch.cuda.synchronize()
for i in range(niter):
pair_mm = mm[pairs_ids[:,0][:,None], pairs_ids[:,1:]]
outfusion = pair_mm[:,0,:] - pair_mm[:,1,:]
out = module(outfusion)
out.sum().backward()
torch.cuda.synchronize()
print(time.time() - t)
import time
import torch
import itertools
def make_pairs_ids(nregion, bsize):
pairs_ids = []
for batch_id in range(bsize):
pairs_id = torch.tensor([
(batch_id,i,j) for i,j in \
itertools.product(range(nregion),repeat=2)],
requires_grad=False,
dtype=torch.long)
pairs_ids.append(pairs_id)
out = torch.cat(pairs_ids).contiguous()
return out
if __name__ == '__main__':
niter=10
bsize=32
nregion=36
dimh=2048
module = torch.nn.Linear(dimh, dimh)
mm = torch.randn(bsize, nregion, dimh, requires_grad=True)
pairs_ids = make_pairs_ids(nregion, bsize)
module.cuda()
mm = mm.cuda()
pairs_ids = pairs_ids.cuda()
t = time.time()
torch.cuda.synchronize()
for i in range(niter):
pair_mm = mm[pairs_ids[:,0][:,None], pairs_ids[:,1:]]
pair_mm.detach_()
# non symetrical fusion
outfusion = pair_mm[:,0,:] - pair_mm[:,1,:]
out = module(outfusion)
out.sum().backward()
torch.cuda.synchronize()
print(time.time() - t)
import time
import torch
import itertools
def make_pairs_ids(nregion, bsize):
pairs_ids = []
for batch_id in range(bsize):
pairs_id = torch.tensor([
(batch_id,i,j) for i,j in \
itertools.product(range(nregion),repeat=2)],
requires_grad=False,
dtype=torch.long)
pairs_ids.append(pairs_id)
out = torch.cat(pairs_ids).contiguous()
return out
if __name__ == '__main__':
niter=10
bsize=32
nregion=36
dimh=2048
module = torch.nn.Linear(dimh, dimh)
mm = torch.randn(bsize, nregion, dimh, requires_grad=True)
pairs_ids = make_pairs_ids(nregion, bsize)
module.cuda()
mm = mm.cuda()
pairs_ids = pairs_ids.cuda()
t = time.time()
torch.cuda.synchronize()
for i in range(niter):
with torch.no_grad():
pair_mm = mm[pairs_ids[:,0][:,None], pairs_ids[:,1:]]
# non symetrical fusion
outfusion = pair_mm[:,0,:] - pair_mm[:,1,:]
out = module(outfusion)
out.sum().backward()
torch.cuda.synchronize()
print(time.time() - t)
import time
import torch
import itertools
def make_pairs_ids(nregion, bsize):
pairs_ids = []
for batch_id in range(bsize):
pairs_id = torch.tensor([
(batch_id,i,j) for i,j in \
itertools.product(range(nregion),repeat=2)],
requires_grad=False,
dtype=torch.long)
pairs_ids.append(pairs_id)
out = torch.cat(pairs_ids).contiguous()
return out
if __name__ == '__main__':
niter=10
bsize=32
nregion=36
dimh=2048
module = torch.nn.Linear(dimh, dimh)
mm = torch.randn(bsize, nregion, dimh, requires_grad=True)
pairs_ids = make_pairs_ids(nregion, bsize)
module.cuda()
mm = mm.cuda()
pairs_ids = pairs_ids.cuda()
t = time.time()
torch.cuda.synchronize()
for i in range(niter):
pair_mm = mm[pairs_ids[:,0][:,None], pairs_ids[:,1:]]
# non symetrical fusion
outfusion = pair_mm[:,0,:] - pair_mm[:,1,:]
out = module(outfusion)
out.sum().backward()
torch.cuda.synchronize()
print(time.time() - t)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment