import math

import torch
import torch.nn.functional as F
import torch_xla  # needed so the 'xla' device is available to PyTorch
def test_conv2d(batch, input_shape, kernel, Cin, Cout, stride, padding, dilation, groups, device):
    """Run one forward + backward pass of F.conv2d and return the output and gradients."""
    input = torch.ones(batch, Cin, input_shape, input_shape, dtype=torch.float, device=device).requires_grad_()
    weight = torch.ones(Cout, Cin // groups, kernel, kernel, dtype=torch.float, device=device).requires_grad_()
    # input and weight already live on `device`, so the output does too.
    out = F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
    out.sum().backward()
    return (out, input.grad, weight.grad)
def test_conv2d_transposed(batch, input_shape, kernel, Cin, Cout, stride, output_padding, dilation, groups, device):
    """Run one forward + backward pass of F.conv_transpose2d and return the output and gradients."""
    input = torch.ones(batch, Cin, input_shape, input_shape, dtype=torch.float, device=device).requires_grad_()
    weight = torch.ones(Cin, Cout // groups, kernel, kernel, dtype=torch.float, device=device).requires_grad_()
    # The sweep below varies output_padding (which must stay smaller than
    # stride or dilation), so pass it as output_padding rather than padding.
    out = F.conv_transpose2d(input, weight, stride=stride, output_padding=output_padding, dilation=dilation, groups=groups)
    out.sum().backward()
    return (out, input.grad, weight.grad)
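# Note: the sweeps below count exactly-mismatched elements with `!=`. For
# floating-point comparisons across backends, a tolerance-based check is
# usually more robust; a minimal sketch (the helper name and tolerance
# values are assumptions, not part of the original gist):
def count_mismatches(a, b, rtol=1e-4, atol=1e-5):
    # Count elements that differ beyond the given tolerances.
    return (~torch.isclose(a.cpu(), b.cpu(), rtol=rtol, atol=atol)).sum().item()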
# Sweep conv2d configurations and compare XLA against CUDA.
for batch in (1,):
    for input_shape in range(12, 20):
        for kernel in range(2, input_shape - 1):
            for Cin in (3, 9):
                for Cout in (6, 9):
                    for stride in range(1, 5):
                        for dilation in range(1, math.floor(input_shape / kernel)):
                            for padding in range(0, 10):
                                for groups in (1, 3):
                                    print('batch={}, input_shape={}, kernel={}, Cin={}, Cout={}, stride={}, padding={}, dilation={}, groups={}'.format(
                                        batch, input_shape, kernel, Cin, Cout, stride, padding, dilation, groups))
                                    out, input_grad, weight_grad = test_conv2d(batch, input_shape, kernel, Cin, Cout, stride, padding, dilation, groups, 'xla:0')
                                    out_cuda, input_grad_cuda, weight_grad_cuda = test_conv2d(batch, input_shape, kernel, Cin, Cout, stride, padding, dilation, groups, 'cuda')
                                    # Count elements that differ exactly between the two backends.
                                    diff_out = (out.cpu() != out_cuda.cpu()).sum()
                                    diff_grad_input = (input_grad.cpu() != input_grad_cuda.cpu()).sum()
                                    diff_grad_weight = (weight_grad.cpu() != weight_grad_cuda.cpu()).sum()
                                    if diff_out != 0 or diff_grad_input != 0 or diff_grad_weight != 0:
                                        print('Diff: out: {}, grad_input: {}, grad_weight: {}'.format(diff_out, diff_grad_input, diff_grad_weight))
# Sweep conv_transpose2d configurations and compare CPU against CUDA.
for batch in (1,):
    for input_shape in range(12, 20):
        for kernel in range(2, input_shape - 1):
            for Cin in (3, 9):
                for Cout in (6, 9):
                    for stride in range(1, 8):
                        for dilation in range(1, math.floor(input_shape / kernel)):
                            # Keep output_padding smaller than both stride and dilation.
                            for output_padding in range(0, min(stride, dilation)):
                                for groups in (1, 3):
                                    print('batch={}, input_shape={}, kernel={}, Cin={}, Cout={}, stride={}, output_padding={}, dilation={}, groups={}'.format(
                                        batch, input_shape, kernel, Cin, Cout, stride, output_padding, dilation, groups))
                                    out, input_grad, weight_grad = test_conv2d_transposed(batch, input_shape, kernel, Cin, Cout, stride, output_padding, dilation, groups, 'cpu')
                                    out_cuda, input_grad_cuda, weight_grad_cuda = test_conv2d_transposed(batch, input_shape, kernel, Cin, Cout, stride, output_padding, dilation, groups, 'cuda')
                                    # Count elements that differ exactly between the two backends.
                                    diff_out = (out.cpu() != out_cuda.cpu()).sum()
                                    diff_grad_input = (input_grad.cpu() != input_grad_cuda.cpu()).sum()
                                    diff_grad_weight = (weight_grad.cpu() != weight_grad_cuda.cpu()).sum()
                                    if diff_out != 0 or diff_grad_input != 0 or diff_grad_weight != 0:
                                        print('Diff: out: {}, grad_input: {}, grad_weight: {}'.format(diff_out, diff_grad_input, diff_grad_weight))
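# For reference, conv2d's output spatial size follows the documented formula
#   out = floor((in + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1.
# A minimal CPU-only sanity check for one sweep-style configuration (the
# specific values here are illustrative, not from the original gist):
in_size, k, s, p, d = 12, 3, 2, 1, 1
expected = (in_size + 2 * p - d * (k - 1) - 1) // s + 1
out, _, _ = test_conv2d(1, in_size, k, 3, 6, s, p, d, 1, 'cpu')
assert out.shape[-1] == expected, (out.shape, expected)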