Skip to content

Instantly share code, notes, and snippets.

@mattiaspaul
Created April 25, 2023 19:19
Show Gist options
  • Save mattiaspaul/b63cd65c9afa4290b316d9297e19ca03 to your computer and use it in GitHub Desktop.
Save mattiaspaul/b63cd65c9afa4290b316d9297e19ca03 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from tqdm import trange
def conv3d_trial_score(i, device='mps'):
    """Compare one randomized Conv3d forward/backward pass on CPU vs `device`.

    Padding, dilation, stride and kernel size are drawn at random (in that
    order, one `torch.randint` call each); the groups/bias combination is
    selected by `i % 4`. The same layer (deep-copied) and the same all-ones
    input are evaluated on the CPU and on `device`, and the output plus all
    gradients are compared with `torch.allclose(rtol=1e-2, atol=1e-3)`.

    Args:
        i: Trial index. `i % 4` selects (groups, bias):
            0 -> (4, True), 1 -> (1, False), 2 -> (64, False), 3 -> (1, True).
        device: Device whose results are checked against the CPU reference.

    Returns:
        Fraction in [0.0, 1.0] of compared tensors that match: output,
        input gradient, weight gradient, and bias gradient when bias is used.
    """
    # .tolist() yields plain Python ints; the original passed tuples of
    # 0-dim tensors to nn.Conv3d.
    pad = tuple(torch.randint(3, (3,)).tolist())
    dilation = tuple((torch.randint(2, (3,)) + 1).tolist())
    stride = tuple((torch.randint(2, (3,)) + 1).tolist())
    kernel = tuple((torch.randint(4, (3,)) + 1).tolist())

    groups, bias = {
        0: (4, True),
        1: (1, False),
        2: (64, False),
        3: (1, True),
    }[i % 4]

    A = torch.ones(2, 64, 16, 16, 16)
    A_dev = A.clone().to(device)
    conv3d = nn.Conv3d(64, 64, kernel, groups=groups, padding=pad,
                       stride=stride, dilation=dilation, bias=bias)
    # Deep copy so both devices start from identical weights.
    conv3d_dev = copy.deepcopy(conv3d).to(device)
    A_dev.requires_grad = True
    A.requires_grad = True

    B = conv3d(A)
    B_dev = conv3d_dev(A_dev)
    B.pow(2).mul(.5).mean().backward()
    B_dev.pow(2).mul(.5).mean().backward()

    # (CPU reference, device result) pairs to compare.
    checks = [
        (B.detach(), B_dev.detach()),
        (A.grad, A_dev.grad),
        (conv3d.weight.grad, conv3d_dev.weight.grad),
    ]
    if bias:
        checks.append((conv3d.bias.grad, conv3d_dev.bias.grad))

    n_ok = sum(
        torch.allclose(ref, out.cpu(), rtol=1e-02, atol=1e-03)
        for ref, out in checks
    )
    return n_ok / len(checks)


# Per-trial score: 1.0 means every compared tensor matched the CPU reference.
passed = torch.zeros(50)
if torch.backends.mps.is_available():
    for i in trange(50):
        passed[i] = conv3d_trial_score(i)
    print(passed)
else:
    # Without this guard the script would fail at the first .to('mps').
    print("MPS backend is not available; skipping conv3d correctness test")
#speed test
def conv3d_fwd_bwd(conv, x):
    """Run one forward and backward pass of `conv` on `x`.

    The loss is the mean of half the squared layer output.

    Args:
        conv: The module to evaluate (called as `conv(x)`).
        x: Input tensor for `conv`.

    Returns:
        The scalar loss tensor; gradients are accumulated by `backward()`.
    """
    loss = conv(x).pow(2).mul(.5).mean()
    loss.backward()
    return loss


if torch.backends.mps.is_available():
    print("testing conv3d on mps for speed vs cpu")
    A = torch.ones(1, 64, 32, 32, 32)
    A_mps = A.clone().to('mps')
    w = nn.Conv3d(64, 64, 3, padding=1, bias=False)
    w_mps = copy.deepcopy(w).to('mps')

    # CPU timing (rate shown by the tqdm progress bar).
    A.requires_grad = True
    # requires_grad_() is the Module API; assigning `w.requires_grad = True`
    # only sets an unrelated attribute on the module object.
    w.requires_grad_(True)
    for _ in trange(5):
        loss = conv3d_fwd_bwd(w, A)

    # MPS timing.
    A_mps.requires_grad = True
    w_mps.requires_grad_(True)
    # Drain pending MPS work before the timed loop starts.
    torch.mps.synchronize()
    for _ in trange(5):
        loss_mps = conv3d_fwd_bwd(w_mps, A_mps)
        # Wait for the MPS kernels to finish so each iteration's time covers
        # the actual computation, not just queuing it.
        torch.mps.synchronize()
else:
    # Without this guard the script would fail at the first .to('mps').
    print("MPS backend is not available; skipping conv3d speed test")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment