Created
December 15, 2022 13:27
-
-
Save vivekkhandelwal1/a90abb3c700d31352c7ecd825198599c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch_mlir | |
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend | |
# Fix the RNG seed so every run compares the same random operands.
torch.manual_seed(0)

# Operands of the convolution whose backward pass is being studied:
# a 3x3, stride-2, padding-1 convolution mapping a 64x64 input to a 32x32
# output, with 320 input and 320 output channels.
grad_out = torch.randn((1, 320, 32, 32))   # gradient w.r.t. the conv output
input_vec = torch.randn((1, 320, 64, 64))  # forward input
weight = torch.randn((320, 320, 3, 3))     # forward weight

# Remaining arguments of aten.convolution_backward.
bias_sizes_ = [320]
stride_ = [2, 2]
padding_ = [1, 1]
dilation_ = [1, 1]
transposed_ = False
output_padding_ = [0, 0]
groups_ = 1
output_mask_ = [True, True, True]  # request grad_input, grad_weight, grad_bias

# Reference results straight from PyTorch; the sections below compare their
# hand-written decompositions against these three tensors.
grad_input, grad_weight, grad_bias = torch.ops.aten.convolution_backward(
    grad_out,
    input_vec,
    weight,
    bias_sizes=bias_sizes_,
    stride=stride_,
    padding=padding_,
    dilation=dilation_,
    transposed=transposed_,
    output_padding=output_padding_,
    groups=groups_,
    output_mask=output_mask_,
)
############### DECOMPOSITION GRAD INPUT ##############
# Express grad_input as a regular (non-transposed) aten.convolution of
# grad_out with the spatially flipped, channel-transposed weight.
# NOTE(review): the padding below is chosen so the result has the same spatial
# extent as input_vec. With stride_ > 1 a plain strided convolution over
# grad_out is presumably not the exact adjoint of the forward convolution
# (that would need a transposed convolution) -- confirm against the
# MIN/MAX/MEAN report printed later.
print("GRAD INPUT")
print("PyTorch GRAD INPUT SHAPE", grad_input.shape)

# Spatial extents of the tensor to reconstruct and of the incoming gradient.
outdimA = input_vec.shape[2]
outdimB = input_vec.shape[3]
gradoutdimA = grad_out.shape[2]
gradoutdimB = grad_out.shape[3]

# floor(k / 2) * 2 + 1 keeps an odd kernel extent k and turns an even k into
# k + 1. The aten scalar ops are used on plain Python ints here.
weightdimA = torch.ops.aten.floordiv(weight.shape[2], 2) * 2 + 1
weightdimB = torch.ops.aten.floordiv(weight.shape[3], 2) * 2 + 1

# Per-side padding p solving
#   outdim = floor((gradoutdim + 2 * p - weightdim) / stride) + 1,
# rounded up when the total padding is odd.
decomp_paddingA = torch.ops.aten.ceil(
    (((outdimA - 1) * stride_[0]) + weightdimA - gradoutdimA) / 2
)
decomp_paddingB = torch.ops.aten.ceil(
    (((outdimB - 1) * stride_[1]) + weightdimB - gradoutdimB) / 2
)
decomp_padding_ = [decomp_paddingA, decomp_paddingB]

# Everything else is forwarded unchanged from the original convolution.
decomp_bias_sizes_ = bias_sizes_
decomp_stride_ = stride_
decomp_dilation_ = dilation_
decomp_transposed_ = transposed_
decomp_output_padding_ = output_padding_
decomp_groups_ = groups_

# Flip the kernel along both spatial dims, then swap its first two dims.
axes = [2, 3]
weight_flip = torch.ops.aten.flip(weight, axes)
weight_transposed = torch.ops.aten.transpose(weight_flip, 0, 1)

decomp_grad_input = torch.ops.aten.convolution(
    grad_out,
    weight_transposed,
    None,  # no bias
    decomp_stride_,
    decomp_padding_,
    decomp_dilation_,
    decomp_transposed_,
    decomp_output_padding_,
    decomp_groups_,
)
print("Decomp GRAD INPUT SHAPE", decomp_grad_input.shape)
######### TORCH-MLIR GRAD INPUT ###############
class ConvBackwardGradInput(torch.nn.Module):
    """Grad-input candidate convolution, written so torch-mlir can compile it.

    All hyper-parameters are literal constants: stride [2, 2],
    padding [49, 49], dilation [1, 1], not transposed, output padding [0, 0],
    one group, no bias. The [49, 49] padding is the value the surrounding
    script computes as ``decomp_padding_`` for its 64x64 input, 32x32
    gradient and 3x3 kernel.
    """

    def __init__(self):
        super().__init__()

    def forward(self, grad_outTM, weightTM):
        """Convolve ``grad_outTM`` with ``weightTM`` using the fixed settings.

        Args:
            grad_outTM: tensor passed as the convolution input.
            weightTM: tensor passed as the convolution weight.

        Returns:
            The result of ``torch.ops.aten.convolution``.
        """
        return torch.ops.aten.convolution(
            grad_outTM,
            weightTM,
            None,      # bias
            [2, 2],    # stride
            [49, 49],  # padding
            [1, 1],    # dilation
            False,     # transposed
            [0, 0],    # output_padding
            1,         # groups
        )
# Lower the module to linalg-on-tensors, build it with the reference backend
# and run it on the same operands the eager decomposition used.
conv_backward_GI_class = ConvBackwardGradInput()
module = torch_mlir.compile(
    conv_backward_GI_class,
    [grad_out, weight_transposed],
    output_type="linalg-on-tensors",
)
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)
tm_grad_input = torch.from_numpy(
    jit_module.forward(grad_out.numpy(), weight_transposed.numpy())
)
print("T-MLIR GRAD INPUT SHAPE", tm_grad_input.shape)

# Summary statistics, printed in the order: PyTorch reference, eager
# decomposition, torch-mlir result.
for stat_label, stat_op in (
    ("MIN: ", torch.ops.aten.min),
    ("MAX: ", torch.ops.aten.max),
    ("MEAN: ", torch.ops.aten.mean),
):
    print(
        stat_label,
        stat_op(grad_input),
        stat_op(decomp_grad_input),
        stat_op(tm_grad_input),
    )
############### DECOMPOSITION GRAD WEIGHT ##############
# For the forward convolution (groups == 1)
#   out[n, co, i, j] = sum_{ci, a, b} w[co, ci, a, b]
#                      * x_pad[n, ci, s*i + d*a, s*j + d*b]
# the weight gradient is
#   gw[co, ci, a, b] = sum_{n, i, j} grad_out[n, co, i, j]
#                      * x_pad[n, ci, s*i + d*a, s*j + d*b].
# That is itself a convolution: the input with dims 0/1 swapped acts as the
# image, grad_out with dims 0/1 swapped acts as the kernel, the forward
# *dilation* becomes the stride and the forward *stride* becomes the dilation.
# The raw result is laid out (C_in, C_out, a, b) and may have extra trailing
# rows/columns, so it is cropped to the kernel extent and dims 0/1 are swapped.
# An ad-hoc stride only reproduces the 3x3 shape, not the values.
print("GRAD WEIGHT")
print("PyTorch GRAD WEIGHT SHAPE", grad_weight.shape)

decomp_stride_ = dilation_               # forward dilation -> stride
decomp_padding_ = padding_               # padding is unchanged
decomp_grad_weight_dilation_ = stride_   # forward stride -> dilation

grad_out_transposed = torch.ops.aten.transpose(grad_out, 0, 1)
input_vec_transposed = torch.ops.aten.transpose(input_vec, 0, 1)


def _crop_and_transpose_grad_weight(raw_conv_out):
    """Crop a raw (C_in, C_out, H, W) result to the kernel extent and swap dims 0/1."""
    cropped = raw_conv_out[:, :, : weight.shape[2], : weight.shape[3]]
    return torch.ops.aten.transpose(cropped, 0, 1)


decomp_grad_weight = _crop_and_transpose_grad_weight(
    torch.ops.aten.convolution(
        input_vec_transposed,
        grad_out_transposed,
        None,  # no bias
        decomp_stride_,
        decomp_padding_,
        decomp_grad_weight_dilation_,
        decomp_transposed_,
        decomp_output_padding_,
        decomp_groups_,
    )
)
print("Decomp GRAD WEIGHT SHAPE", decomp_grad_weight.shape)


######### TORCH-MLIR GRAD WEIGHT ###############
class ConvBackwardGradWeight(torch.nn.Module):
    """Grad-weight convolution, written so torch-mlir can compile it.

    The literal settings mirror the eager decomposition for the script's
    forward convolution (stride [2, 2], padding [1, 1], dilation [1, 1]):
    stride [1, 1], padding [1, 1], dilation [2, 2], not transposed,
    output padding [0, 0], one group, no bias. Cropping and the final
    transpose are applied by the caller.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input_vecTM, grad_outTM):
        """Convolve ``input_vecTM`` with ``grad_outTM`` as the kernel."""
        return torch.ops.aten.convolution(
            input_vecTM,
            grad_outTM,
            None,     # bias
            [1, 1],   # stride   (forward dilation)
            [1, 1],   # padding
            [2, 2],   # dilation (forward stride)
            False,    # transposed
            [0, 0],   # output_padding
            1,        # groups
        )


conv_backward_GW_class = ConvBackwardGradWeight()
module = torch_mlir.compile(
    conv_backward_GW_class,
    [input_vec_transposed, grad_out_transposed],
    output_type="linalg-on-tensors",
)
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)
tm_grad_weight = _crop_and_transpose_grad_weight(
    torch.from_numpy(
        jit_module.forward(
            input_vec_transposed.numpy(), grad_out_transposed.numpy()
        )
    )
)
print("T-MLIR GRAD WEIGHT SHAPE", tm_grad_weight.shape)

# Summary statistics, printed in the order: PyTorch reference, eager
# decomposition, torch-mlir result.
for stat_label, stat_op in (
    ("MIN: ", torch.ops.aten.min),
    ("MAX: ", torch.ops.aten.max),
    ("MEAN: ", torch.ops.aten.mean),
):
    print(
        stat_label,
        stat_op(grad_weight),
        stat_op(decomp_grad_weight),
        stat_op(tm_grad_weight),
    )
############### DECOMPOSITION GRAD BIAS ##############
# Decompose grad_bias as the sum of grad_out over dims 0, 2 and 3, which
# leaves only dim 1.
print("GRAD BIAS")
print("PyTorch GRAD BIAS SHAPE", grad_bias.shape)
decomp_grad_bias = torch.ops.aten.sum.dim_IntList(
    grad_out, [0, 2, 3], keepdim=False
)
print("Decomp GRAD BIAS SHAPE", decomp_grad_bias.shape)
######### TORCH-MLIR GRAD BIAS ###############
class ConvBackwardGradBias(torch.nn.Module):
    """Grad-bias reduction, written so torch-mlir can compile it."""

    def __init__(self):
        super().__init__()

    def forward(self, grad_outTM):
        """Sum ``grad_outTM`` over dims 0, 2 and 3 without keeping them.

        Args:
            grad_outTM: tensor with at least four dims.

        Returns:
            A tensor holding one sum per index of dim 1.
        """
        return torch.ops.aten.sum(grad_outTM, [0, 2, 3], keepdim=False)
# Lower the reduction to linalg-on-tensors, build it with the reference
# backend and run it on grad_out.
conv_backward_GB_class = ConvBackwardGradBias()
module = torch_mlir.compile(
    conv_backward_GB_class, [grad_out], output_type="linalg-on-tensors"
)
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)
tm_grad_bias = torch.from_numpy(jit_module.forward(grad_out.numpy()))
print("T-MLIR GRAD BIAS SHAPE", tm_grad_bias.shape)

# Summary statistics, printed in the order: PyTorch reference, eager
# decomposition, torch-mlir result.
for stat_label, stat_op in (
    ("MIN: ", torch.ops.aten.min),
    ("MAX: ", torch.ops.aten.max),
    ("MEAN: ", torch.ops.aten.mean),
):
    print(
        stat_label,
        stat_op(grad_bias),
        stat_op(decomp_grad_bias),
        stat_op(tm_grad_bias),
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment