import torch
import torch_mlir
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
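
# Compare three results for each gradient of aten.convolution_backward:
# eager PyTorch, a decomposition into forward aten.convolution / aten.sum
# ops, and the same decomposition compiled through Torch-MLIR's RefBackend.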

torch.manual_seed(0)

# Shapes for the forward conv being differentiated: input (N=1, C_in=320,
# 64x64), weight (C_out=4, C_in=320, 3x3), so grad_out is (N=1, C_out=4, 64x64).
grad_out = torch.randn((1, 4, 64, 64))
input_vec = torch.randn((1, 320, 64, 64))
weight = torch.randn((4, 320, 3, 3))
bias_sizes_ = [4]
stride_ = [1, 1]
padding_ = [1, 1]
dilation_ = [1, 1]
transposed_ = False
output_padding_ = [0, 0]
groups_ = 1
# Request all three gradients: input, weight, bias.
output_mask_ = [True, True, True]

grad_input, grad_weight, grad_bias = torch.ops.aten.convolution_backward(
    grad_out,
    input_vec,
    weight,
    bias_sizes=bias_sizes_,
    stride=stride_,
    padding=padding_,
    dilation=dilation_,
    transposed=transposed_,
    output_padding=output_padding_,
    groups=groups_,
    output_mask=output_mask_,
)
# print(grad_input, grad_weight, grad_bias)

############### DECOMPOSITION GRAD INPUT ##############
print("GRAD INPUT")
print("PyTorch GRAD INPUT SHAPE", grad_input.shape)
outdimA = input_vec.shape[2]
outdimB = input_vec.shape[3]
gradoutdimA = grad_out.shape[2]
gradoutdimB = grad_out.shape[3]
# floordiv(k, 2) * 2 + 1 rounds the kernel extent up to the nearest odd size
# (a no-op for the 3x3 kernel used here).
weightdimA = torch.ops.aten.floordiv(weight.shape[2], 2) * 2 + 1
weightdimB = torch.ops.aten.floordiv(weight.shape[3], 2) * 2 + 1
# Padding needed so the conv below reproduces the input's spatial dims.
decomp_paddingA = torch.ops.aten.ceil(
    (((outdimA - 1) * stride_[0]) + weightdimA - gradoutdimA) / 2
)
decomp_paddingB = torch.ops.aten.ceil(
    (((outdimB - 1) * stride_[1]) + weightdimB - gradoutdimB) / 2
)
decomp_padding_ = [decomp_paddingA, decomp_paddingB]
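# Worked example for the shapes above (comment added, not in the original
# gist): outdim = 64, stride = 1, weightdim = 3, gradoutdim = 64, so each
# padding is ceil(((64 - 1) * 1 + 3 - 64) / 2) = ceil(2 / 2) = 1.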
decomp_bias_sizes_ = bias_sizes_
decomp_stride_ = stride_
decomp_dilation_ = dilation_
decomp_transposed_ = transposed_
decomp_output_padding_ = output_padding_
decomp_groups_ = groups_
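# Why flip + transpose (comment added, not in the original gist): for
# stride 1, the input gradient of a convolution is itself a convolution of
# grad_out with the weight flipped along its spatial axes and with its
# output/input channel dims swapped.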
axes = [2, 3]
weight_flip = torch.ops.aten.flip(weight, axes)
weight_transposed = torch.ops.aten.transpose(weight_flip, 0, 1)
decomp_grad_input = torch.ops.aten.convolution(
    grad_out,
    weight_transposed,
    None,
    decomp_stride_,
    decomp_padding_,
    decomp_dilation_,
    decomp_transposed_,
    decomp_output_padding_,
    decomp_groups_,
)
print("Decomp GRAD INPUT SHAPE", decomp_grad_input.shape)

######### TORCH-MLIR GRAD INPUT ###############
class ConvBackwardGradInput(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, grad_outTM, weightTM):
        # stride, padding, dilation, transposed, output_padding, and groups
        # are hardcoded to the values computed above for this example.
        return torch.ops.aten.convolution(
            grad_outTM,
            weightTM,
            None,
            [1, 1],
            [1, 1],
            [1, 1],
            False,
            [0, 0],
            1,
        )

conv_backward_GI_class = ConvBackwardGradInput()
module = torch_mlir.compile(
    conv_backward_GI_class,
    [grad_out, weight_transposed],
    output_type="linalg-on-tensors",
)
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)
tm_grad_input = torch.from_numpy(
    jit_module.forward(grad_out.numpy(), weight_transposed.numpy())
)
print("T-MLIR GRAD INPUT SHAPE", tm_grad_input.shape)
print(
    "MIN: ",
    torch.ops.aten.min(grad_input),
    torch.ops.aten.min(decomp_grad_input),
    torch.ops.aten.min(tm_grad_input),
)
print(
    "MAX: ",
    torch.ops.aten.max(grad_input),
    torch.ops.aten.max(decomp_grad_input),
    torch.ops.aten.max(tm_grad_input),
)
print(
    "MEAN: ",
    torch.ops.aten.mean(grad_input),
    torch.ops.aten.mean(decomp_grad_input),
    torch.ops.aten.mean(tm_grad_input),
)
# grad_input_shape = grad_input.shape
# for i in range(grad_input_shape[0]):
#     for j in range(grad_input_shape[1]):
#         for m in range(grad_input_shape[2]):
#             for n in range(grad_input_shape[3]):
#                 print(tm_grad_input[i][j][m][n] == decomp_grad_input[i][j][m][n])
#                 print(grad_input[i][j][m][n], decomp_grad_input[i][j][m][n], tm_grad_input[i][j][m][n])
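# Hedged check (not in the original gist): a vectorized equivalent of the
# commented-out loop above; the tolerances are a guess.
print(
    "T-MLIR ALLCLOSE",
    torch.allclose(tm_grad_input, decomp_grad_input, rtol=1e-4, atol=1e-5),
    torch.allclose(tm_grad_input, grad_input, rtol=1e-4, atol=1e-5),
)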

############### DECOMPOSITION GRAD WEIGHT ##############
print("GRAD WEIGHT")
print("PyTorch GRAD WEIGHT SHAPE", grad_weight.shape)
# axes = [2, 3]
# weight_flip = torch.ops.aten.flip(weight, axes)
# decomp_stride_ = [1, 1]
decomp_padding_ = padding_
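# Why the transposes (comment added, not in the original gist): swapping the
# batch and channel dims of both input and grad_out turns the weight gradient
# into a forward convolution. The result comes out with dims 0 and 1 swapped
# relative to PyTorch's grad_weight, so the shapes printed below differ even
# though the min/max/mean statistics agree.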
grad_out_transposed = torch.ops.aten.transpose(grad_out, 0, 1)
input_vec_transposed = torch.ops.aten.transpose(input_vec, 0, 1)
decomp_grad_weight = torch.ops.aten.convolution(
    input_vec_transposed,
    grad_out_transposed,
    None,
    decomp_stride_,
    decomp_padding_,
    decomp_dilation_,
    decomp_transposed_,
    decomp_output_padding_,
    decomp_groups_,
)
print("Decomp GRAD WEIGHT SHAPE", decomp_grad_weight.shape)

######### TORCH-MLIR GRAD WEIGHT ###############
class ConvBackwardGradWeight(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input_vecTM, grad_outTM):
        # Hyperparameters hardcoded to the values used above.
        return torch.ops.aten.convolution(
            input_vecTM,
            grad_outTM,
            None,
            [1, 1],
            [1, 1],
            [1, 1],
            False,
            [0, 0],
            1,
        )

conv_backward_GW_class = ConvBackwardGradWeight()
module = torch_mlir.compile(
    conv_backward_GW_class,
    [input_vec_transposed, grad_out_transposed],
    output_type="linalg-on-tensors",
)
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)
tm_grad_weight = torch.from_numpy(
    jit_module.forward(input_vec_transposed.numpy(), grad_out_transposed.numpy())
)
print("T-MLIR GRAD WEIGHT SHAPE", tm_grad_weight.shape)
print(
    "MIN: ",
    torch.ops.aten.min(grad_weight),
    torch.ops.aten.min(decomp_grad_weight),
    torch.ops.aten.min(tm_grad_weight),
)
print(
    "MAX: ",
    torch.ops.aten.max(grad_weight),
    torch.ops.aten.max(decomp_grad_weight),
    torch.ops.aten.max(tm_grad_weight),
)
print(
    "MEAN: ",
    torch.ops.aten.mean(grad_weight),
    torch.ops.aten.mean(decomp_grad_weight),
    torch.ops.aten.mean(tm_grad_weight),
)
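# Hedged check (not in the original gist): the compiled decomposition should
# closely match the eager decomposition; the tolerances are loose guesses.
print(
    "T-MLIR ALLCLOSE",
    torch.allclose(tm_grad_weight, decomp_grad_weight, rtol=1e-3, atol=1e-3),
)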

############### DECOMPOSITION GRAD BIAS ##############
print("GRAD BIAS")
print("PyTorch GRAD BIAS SHAPE", grad_bias.shape)
# The bias gradient is just grad_out summed over the batch and spatial dims.
decomp_grad_bias = torch.ops.aten.sum.dim_IntList(grad_out, [0, 2, 3], keepdim=False)
print("Decomp GRAD BIAS SHAPE", decomp_grad_bias.shape)

######### TORCH-MLIR GRAD BIAS ###############
class ConvBackwardGradBias(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, grad_outTM):
        return torch.ops.aten.sum(grad_outTM, [0, 2, 3], keepdim=False)


conv_backward_GB_class = ConvBackwardGradBias()
module = torch_mlir.compile(
    conv_backward_GB_class, [grad_out], output_type="linalg-on-tensors"
)
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)
tm_grad_bias = torch.from_numpy(jit_module.forward(grad_out.numpy()))
print("T-MLIR GRAD BIAS SHAPE", tm_grad_bias.shape)
print(
    "MIN: ",
    torch.ops.aten.min(grad_bias),
    torch.ops.aten.min(decomp_grad_bias),
    torch.ops.aten.min(tm_grad_bias),
)
print(
    "MAX: ",
    torch.ops.aten.max(grad_bias),
    torch.ops.aten.max(decomp_grad_bias),
    torch.ops.aten.max(tm_grad_bias),
)
print(
    "MEAN: ",
    torch.ops.aten.mean(grad_bias),
    torch.ops.aten.mean(decomp_grad_bias),
    torch.ops.aten.mean(tm_grad_bias),
)
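# Hedged check (not in the original gist): the tolerances are a guess.
print(
    "T-MLIR ALLCLOSE",
    torch.allclose(tm_grad_bias, decomp_grad_bias, rtol=1e-4, atol=1e-5),
)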