Skip to content

Instantly share code, notes, and snippets.

@ailzhang
Last active March 18, 2022 03:24
Show Gist options
  • Save ailzhang/0f7182027f129360890d7a6a83c4c360 to your computer and use it in GitHub Desktop.
import torch
import math
import torch.nn.functional as F
import torch_xla
def test_conv2d(batch, input_shape, kernel, Cin, Cout, stride, padding, dilation, groups, device):
input = torch.ones(batch, Cin, input_shape, input_shape, dtype=torch.float, device=device).requires_grad_()
weight = torch.ones(Cout, int(Cin / groups), kernel, kernel, dtype=torch.float, device=device).requires_grad_()
out = F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups).to(device)
out.sum().backward()
return (out, input.grad, weight.grad)
def test_conv2d_transposed(batch, input_shape, kernel, Cin, Cout, stride, padding, dilation, groups, device):
input = torch.ones(batch, Cin, input_shape, input_shape, dtype=torch.float, device=device).requires_grad_()
weight = torch.ones(Cin, int(Cout / groups), kernel, kernel, dtype=torch.float, device=device).requires_grad_()
out = F.conv_transpose2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups).to(device)
out.sum().backward()
return (out, input.grad, weight.grad)
# Sweep conv2d configurations and compare torch_xla ('xla:0') results against
# CUDA element-for-element.  Requires both an XLA device and a CUDA device.
for batch in (1,):
    for input_shape in range(12, 20):
        for kernel in range(2, input_shape - 1):
            for Cin in (3, 9):
                for Cout in (6, 9):
                    for stride in range(1, 5):
                        # Keep the effective (dilated) kernel inside the input.
                        for dilation in range(1, math.floor(input_shape / kernel)):
                            for padding in range(0, 10):
                                for groups in (1, 3):
                                    print('batch={}, input_shape={}, kernel={}, Cin={}, Cout={}, stride={}, padding={}, dilation={}, groups={}'.format(
                                        batch, input_shape, kernel, Cin, Cout, stride, padding, dilation, groups))
                                    out, input_grad, weight_grad = test_conv2d(batch, input_shape, kernel, Cin, Cout, stride, padding, dilation, groups, 'xla:0')
                                    out_cuda, input_grad_cuda, weight_grad_cuda = test_conv2d(batch, input_shape, kernel, Cin, Cout, stride, padding, dilation, groups, 'cuda')
                                    # Count exactly-unequal elements.  Exact float compare is
                                    # presumably intentional here since all-ones inputs yield
                                    # integer-valued outputs — confirm if tolerances are wanted.
                                    diff_out = (out.cpu() != out_cuda.cpu()).sum()
                                    diff_grad_input = (input_grad.cpu() != input_grad_cuda.cpu()).sum()
                                    diff_grad_weight = (weight_grad.cpu() != weight_grad_cuda.cpu()).sum()
                                    if diff_out != 0 or diff_grad_input != 0 or diff_grad_weight != 0:
                                        print('Diff: out: {}, grad_input: {}, grad_weight: {}'.format(diff_out, diff_grad_input, diff_grad_weight))
# Sweep conv_transpose2d configurations and compare CPU results against CUDA
# element-for-element.  Requires a CUDA device.
# NOTE(review): the innermost-but-two loop variable is named `output_padding`
# and its range obeys the transposed-conv constraint
# output_padding < min(stride, dilation), yet it is passed positionally into
# test_conv2d_transposed's `padding` slot, so it is actually exercised as
# zero-padding — confirm which parameter this sweep was meant to vary.
for batch in (1, ):
    for input_shape in range(12, 20):
        for kernel in range(2, input_shape - 1):
            for Cin in (3, 9):
                for Cout in (6, 9):
                    for stride in range(1, 8):
                        # Keep the effective (dilated) kernel inside the input.
                        for dilation in range(1, math.floor(input_shape / kernel)):
                            for output_padding in range(0, min(stride, dilation)):
                                for groups in (1, 3):
                                    print('batch={}, input_shape={}, kernel={}, Cin={}, Cout={}, stride={}, output_padding={}, dilation={}, groups={}'.format(
                                        batch, input_shape, kernel, Cin, Cout, stride, output_padding, dilation, groups))
                                    out, input_grad, weight_grad = test_conv2d_transposed(batch, input_shape, kernel, Cin, Cout, stride, output_padding, dilation, groups, 'cpu')
                                    out_cuda, input_grad_cuda, weight_grad_cuda = test_conv2d_transposed(batch, input_shape, kernel, Cin, Cout, stride, output_padding, dilation, groups, 'cuda')
                                    # Count exactly-unequal elements between the two backends.
                                    diff_out = (out.cpu() != out_cuda.cpu()).sum()
                                    diff_grad_input = (input_grad.cpu() != input_grad_cuda.cpu()).sum()
                                    diff_grad_weight = (weight_grad.cpu() != weight_grad_cuda.cpu()).sum()
                                    if diff_out != 0 or diff_grad_input != 0 or diff_grad_weight != 0:
                                        print('Diff: out: {}, grad_input: {}, grad_weight: {}'.format(diff_out, diff_grad_input, diff_grad_weight))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment