import torch
import torch_mlir
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
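# This script experiments with decomposing aten.convolution_backward into forward
# aten.convolution / reduction ops. Each gradient (input, weight, bias) is computed
# three ways -- the reference aten.convolution_backward, an eager-mode decomposition,
# and the same decomposition compiled through Torch-MLIR (linalg-on-tensors +
# RefBackend) -- and their min/max/mean statistics are compared.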
torch.manual_seed(0)
grad_out = torch.randn((1, 4, 64, 64))
input_vec = torch.randn((1, 320, 64, 64))
weight = torch.randn((4, 320, 3, 3))
bias_sizes_ = [4]
stride_ = [1, 1]
padding_ = [1, 1]
dilation_ = [1, 1]
transposed_ = False
output_padding_ = [0, 0]
groups_ = 1
output_mask_ = [True, True, True]
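# Reference: compute all three gradients directly with aten.convolution_backward
# (output_mask requests grad_input, grad_weight, and grad_bias).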
grad_input, grad_weight, grad_bias = torch.ops.aten.convolution_backward(
    grad_out,
    input_vec,
    weight,
    bias_sizes=bias_sizes_,
    stride=stride_,
    padding=padding_,
    dilation=dilation_,
    transposed=transposed_,
    output_padding=output_padding_,
    groups=groups_,
    output_mask=output_mask_,
)
# print (grad_input, grad_weight, grad_bias)
############### DECOMPOSITION GRAD INPUT ##############
print("GRAD INPUT")
print("PyTorch GRAD INPUT SHAPE", grad_input.shape)
outdimA = input_vec.shape[2]
outdimB = input_vec.shape[3]
gradoutdimA = grad_out.shape[2]
gradoutdimB = grad_out.shape[3]
weightdimA = torch.ops.aten.floordiv(weight.shape[2], 2) * 2 + 1
weightdimB = torch.ops.aten.floordiv(weight.shape[3], 2) * 2 + 1
decomp_paddingA = torch.ops.aten.ceil(
    (((outdimA - 1) * stride_[0]) + weightdimA - gradoutdimA) / 2
)
decomp_paddingB = torch.ops.aten.ceil(
    (((outdimB - 1) * stride_[1]) + weightdimB - gradoutdimB) / 2
)
decomp_padding_ = [decomp_paddingA, decomp_paddingB]
decomp_bias_sizes_ = bias_sizes_
decomp_stride_ = stride_
decomp_dilation_ = dilation_
decomp_transposed_ = transposed_
decomp_output_padding_ = output_padding_
decomp_groups_ = groups_
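# Flip the kernel along its spatial dims and swap its output/input channel dims
# so it can be used as the weight of a forward convolution over grad_out.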
axes = [2, 3]
weight_flip = torch.ops.aten.flip(weight, axes)
weight_transposed = torch.ops.aten.transpose(weight_flip, 0, 1)
decomp_grad_input = torch.ops.aten.convolution(
    grad_out,
    weight_transposed,
    None,
    decomp_stride_,
    decomp_padding_,
    decomp_dilation_,
    decomp_transposed_,
    decomp_output_padding_,
    decomp_groups_,
)
print("Decomp GRAD INPUT SHAPE", decomp_grad_input.shape)
######### TORCH-MLIR GRAD INPUT ###############
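# Compile the same forward convolution with Torch-MLIR to linalg-on-tensors, run it
# through the reference backend, and compare against the eager results.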
class ConvBackwardGradInput(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, grad_outTM, weightTM):
        return torch.ops.aten.convolution(
            grad_outTM,
            weightTM,
            None,
            [1, 1],
            [1, 1],
            [1, 1],
            False,
            [0, 0],
            1,
        )
conv_backward_GI_class = ConvBackwardGradInput()
module = torch_mlir.compile(
    conv_backward_GI_class, [grad_out, weight_transposed], output_type="linalg-on-tensors"
)
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)
tm_grad_input = torch.from_numpy(
    jit_module.forward(grad_out.numpy(), weight_transposed.numpy())
)
print("T-MLIR GRAD INPUT SHAPE", tm_grad_input.shape)
print(
    "MIN: ",
    torch.ops.aten.min(grad_input),
    torch.ops.aten.min(decomp_grad_input),
    torch.ops.aten.min(tm_grad_input),
)
print(
    "MAX: ",
    torch.ops.aten.max(grad_input),
    torch.ops.aten.max(decomp_grad_input),
    torch.ops.aten.max(tm_grad_input),
)
print(
    "MEAN: ",
    torch.ops.aten.mean(grad_input),
    torch.ops.aten.mean(decomp_grad_input),
    torch.ops.aten.mean(tm_grad_input),
)
# grad_input_shape = grad_input.shape
# for i in range(grad_input_shape[0]):
#     for j in range(grad_input_shape[1]):
#         for m in range(grad_input_shape[2]):
#             for n in range(grad_input_shape[3]):
#                 print(tm_grad_input[i][j][m][n] == decomp_grad_input[i][j][m][n])
#                 print(grad_input[i][j][m][n], decomp_grad_input[i][j][m][n], tm_grad_input[i][j][m][n])
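# Optional element-wise check (not in the original gist): for this stride-1,
# dilation-1 configuration the decomposition should match the reference up to
# floating-point rounding, so a loose tolerance is used.
print(
    "ALLCLOSE: ",
    torch.allclose(grad_input, decomp_grad_input, atol=1e-4),
    torch.allclose(grad_input, tm_grad_input, atol=1e-4),
)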
############### DECOMPOSITION GRAD WEIGHT ##############
print("GRAD WEIGHT")
print("PyTorch GRAD WEIGHT SHAPE", grad_weight.shape)
# axes = [2, 3]
# weight_flip = torch.ops.aten.flip(weight, axes)
# decomp_stride_ = [1, 1]
decomp_padding_ = padding_
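# Decomposition: grad_weight comes from convolving the batch/channel-transposed input
# with the batch/channel-transposed grad_out. Note that this yields a tensor with the
# output/input channel dims swapped relative to grad_weight, so the shape prints below
# differ even though the permutation-invariant min/max/mean statistics still agree.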
grad_out_transposed = torch.ops.aten.transpose(grad_out, 0, 1)
input_vec_transposed = torch.ops.aten.transpose(input_vec, 0, 1)
decomp_grad_weight = torch.ops.aten.convolution(
    input_vec_transposed,
    grad_out_transposed,
    None,
    decomp_stride_,
    decomp_padding_,
    decomp_dilation_,
    decomp_transposed_,
    decomp_output_padding_,
    decomp_groups_,
)
print("Decomp GRAD WEIGHT SHAPE", decomp_grad_weight.shape)
######### TORCH-MLIR GRAD WEIGHT ###############
class ConvBackwardGradWeight(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, input_vecTM, grad_outTM):
        return torch.ops.aten.convolution(
            input_vecTM,
            grad_outTM,
            None,
            [1, 1],
            [1, 1],
            [1, 1],
            False,
            [0, 0],
            1,
        )
conv_backward_GW_class = ConvBackwardGradWeight()
module = torch_mlir.compile(
    conv_backward_GW_class, [input_vec_transposed, grad_out_transposed], output_type="linalg-on-tensors"
)
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)
tm_grad_weight = torch.from_numpy(
    jit_module.forward(input_vec_transposed.numpy(), grad_out_transposed.numpy())
)
print("T-MLIR GRAD WEIGHT SHAPE", tm_grad_weight.shape)
print(
    "MIN: ",
    torch.ops.aten.min(grad_weight),
    torch.ops.aten.min(decomp_grad_weight),
    torch.ops.aten.min(tm_grad_weight),
)
print(
    "MAX: ",
    torch.ops.aten.max(grad_weight),
    torch.ops.aten.max(decomp_grad_weight),
    torch.ops.aten.max(tm_grad_weight),
)
print(
    "MEAN: ",
    torch.ops.aten.mean(grad_weight),
    torch.ops.aten.mean(decomp_grad_weight),
    torch.ops.aten.mean(tm_grad_weight),
)
############### DECOMPOSITION GRAD BIAS ##############
print("GRAD BIAS")
print("PyTorch GRAD BIAS SHAPE", grad_bias.shape)
decomp_grad_bias = torch.ops.aten.sum.dim_IntList(grad_out, [0, 2, 3], keepdim=False)
print("Decomp GRAD BIAS SHAPE", decomp_grad_bias.shape)
######### TORCH-MLIR GRAD BIAS ###############
class ConvBackwardGradBias(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, grad_outTM):
        return torch.ops.aten.sum(grad_outTM, [0, 2, 3], keepdim=False)
conv_backward_GB_class = ConvBackwardGradBias()
module = torch_mlir.compile(
    conv_backward_GB_class, [grad_out], output_type="linalg-on-tensors"
)
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)
tm_grad_bias = torch.from_numpy(
    jit_module.forward(grad_out.numpy())
)
print("T-MLIR GRAD BIAS SHAPE", tm_grad_bias.shape)
print(
    "MIN: ",
    torch.ops.aten.min(grad_bias),
    torch.ops.aten.min(decomp_grad_bias),
    torch.ops.aten.min(tm_grad_bias),
)
print(
    "MAX: ",
    torch.ops.aten.max(grad_bias),
    torch.ops.aten.max(decomp_grad_bias),
    torch.ops.aten.max(tm_grad_bias),
)
print(
    "MEAN: ",
    torch.ops.aten.mean(grad_bias),
    torch.ops.aten.mean(decomp_grad_bias),
    torch.ops.aten.mean(tm_grad_bias),
)
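# Optional check (not in the original gist): the bias reduction should match the
# reference gradient up to floating-point rounding.
print(
    "ALLCLOSE: ",
    torch.allclose(grad_bias, decomp_grad_bias, atol=1e-4),
    torch.allclose(grad_bias, tm_grad_bias, atol=1e-4),
)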