Created
November 10, 2021 02:34
-
-
Save Lyken17/b2fab3902ffa2df9d380ca788ddd175f to your computer and use it in GitHub Desktop.
conv2d_gradient_tvm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tvm | |
from tvm import relay | |
from tvm import relay, auto_scheduler | |
from tvm.relay import testing | |
# Version header prepended to Relay program text before it is parsed
# (see parse_module).
SEMVER = '#[version = "0.0.5"]\n'
def assert_graph_equal(lhs, rhs):
    """Assert that two TVM IR objects are structurally equal.

    Thin wrapper around ``tvm.ir.assert_structural_equal`` that always
    passes ``map_free_vars=True``.

    Args:
        lhs: First IR object to compare.
        rhs: Second IR object to compare.
    """
    tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars=True)
def roundtrip(expr):
    """Check that ``expr`` survives a print-then-parse round trip.

    The object is rendered with ``expr.astext()``, parsed back with
    ``tvm.parser.fromtext`` and compared against the original using
    ``assert_graph_equal``.

    Args:
        expr: IR object exposing ``astext()``.
    """
    reparsed = tvm.parser.fromtext(expr.astext())
    assert_graph_equal(reparsed, expr)
# Testing utilities for full modules.
def parse_module(code):
    """Parse Relay program text into a module and round-trip check it.

    ``SEMVER`` is prepended to ``code`` before it is handed to
    ``tvm.parser.parse``; the parsed module is then passed through
    ``roundtrip`` before being returned.

    Args:
        code: Relay program text without the version header.

    Returns:
        The module produced by ``tvm.parser.parse``.
    """
    versioned_source = SEMVER + code
    mod = tvm.parser.parse(versioned_source)
    roundtrip(mod)
    return mod
# Alternative test case: a single nn.conv2d_transpose. It is kept under its
# own name so it is no longer silently overwritten; assign it to `program`
# below to exercise it instead of the depthwise case.
conv2d_transpose_program = """
def @main(%input0: Tensor[(1, 16, 224, 224), float32],
          %v0_0_weight: Tensor[(16, 32, 3, 3), float32]) -> Tensor[(1, 32, 224, 224), float32] {
  /* test comment */
  %0 = nn.conv2d_transpose(%input0, %v0_0_weight, strides=[1, 1], padding=[1, 1, 1, 1], groups=1, channels=32, kernel_size=[3, 3]);
  %0
}
"""

# Active test case: a single nn.conv2d with groups=32 over 32 channels.
# The body only uses %input0 and %v0_0_weight; the two remaining weight
# parameters are declared but unused.
depthwise_conv2d_program = """
def @main(%input0: Tensor[(1, 32, 224, 224), float32],
          %v0_0_weight: Tensor[(32, 1, 3, 3), float32],
          %v1_conv_0_0_weight: Tensor[(32, 1, 3, 3), float32],
          %v1_conv_1_weight: Tensor[(16, 32, 1, 1), float32]) -> Tensor[(1, 32, 224, 224), float32] {
  %0 = nn.conv2d(%input0, %v0_0_weight, strides=[1, 1], padding=[1, 1, 1, 1], groups=32, channels=32, kernel_size=[3, 3]);
  %0
}
"""

# Program that is actually parsed and built below.
program = depthwise_conv2d_program

# Forward pass: parse, type-check and build for the CPU (LLVM) target.
mod = parse_module(program)
mod = relay.transform.InferType()(mod)
target = "llvm"
lib = relay.build(mod, target=target, params=None)
print("build [fwd] pass successful")

# Backward pass: derive the first-order gradient of @main, wrap it in a
# fresh module and build that module for the same target.
# NOTE(review): InferType already ran above; this second run is kept from
# the original script and looks redundant - confirm before removing.
mod = relay.transform.InferType()(mod)
bwd_ir = relay.transform.gradient(mod['main'], mode="first_order")
bwd_mod = tvm.IRModule.from_expr(bwd_ir)
print(bwd_mod)
lib = relay.build(bwd_mod, target=target, params=None)
# This message previously said "[fwd]" although it reports the backward build.
print("build [bwd] pass successful")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment