Skip to content

Instantly share code, notes, and snippets.

@Lyken17
Created November 10, 2021 02:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Lyken17/b2fab3902ffa2df9d380ca788ddd175f to your computer and use it in GitHub Desktop.
Save Lyken17/b2fab3902ffa2df9d380ca788ddd175f to your computer and use it in GitHub Desktop.
conv2d_gradient_tvm
import numpy as np
import tvm
from tvm import relay
from tvm import relay, auto_scheduler
from tvm.relay import testing
# Relay text-format version header; parse_module() prepends it to every
# program string before handing it to tvm.parser.parse.
SEMVER = '#[version = "0.0.5"]\n'
def assert_graph_equal(lhs, rhs):
    """Raise if two TVM IR objects are not structurally equal.

    Delegates to ``tvm.ir.assert_structural_equal`` with
    ``map_free_vars=True``, so unbound (free) variables on the two sides
    may be mapped onto each other instead of having to be the same object.
    """
    structural_check = tvm.ir.assert_structural_equal
    structural_check(lhs, rhs, map_free_vars=True)
def roundtrip(expr):
    """Check that *expr* survives a print -> parse cycle unchanged.

    *expr* is rendered to Relay text with ``astext()``, parsed back with
    ``tvm.parser.fromtext``, and the reparsed object is compared against
    the original with ``assert_graph_equal`` (which raises on mismatch).
    """
    source_text = expr.astext()
    assert_graph_equal(tvm.parser.fromtext(source_text), expr)
# Testing utilities for full modules.
def parse_module(code):
    """Parse Relay source *code* into an IRModule and verify it round-trips.

    ``SEMVER`` is prepended to *code* before parsing. The parsed module is
    checked with ``roundtrip`` (print -> parse -> structural compare) and
    then returned.
    """
    parsed_module = tvm.parser.parse(SEMVER + code)
    roundtrip(parsed_module)
    return parsed_module
# Alternative test case: a 3x3 conv2d_transpose taking a 16-channel input to
# a 32-channel output. Previously this was assigned to ``program`` and then
# immediately overwritten by the next assignment without ever being used.
# It now has its own name. To exercise it, pass this string to
# ``parse_module`` instead of ``program``.
conv2d_transpose_program = """
def @main(%input0: Tensor[(1, 16, 224, 224), float32],
%v0_0_weight: Tensor[(16, 32, 3, 3), float32]) -> Tensor[(1, 32, 224, 224), float32] {
/* test comment */
%0 = nn.conv2d_transpose(%input0, %v0_0_weight, strides=[1, 1], padding=[1, 1, 1, 1], groups=1, channels=32, kernel_size=[3, 3]);
%0
}
"""
# Active test case: a 3x3 conv2d with groups=32 on a 32-channel input
# (depthwise, given the (32, 1, 3, 3) weight), padding 1 and stride 1, so the
# spatial size stays 224x224. This is the value of ``program`` that the build
# code below passes to parse_module().
# NOTE(review): %v1_conv_0_0_weight and %v1_conv_1_weight are declared as
# parameters of @main but never used in its body. They presumably come from a
# larger exported model. They still become inputs of @main (and presumably
# receive gradients from relay.transform.gradient), so confirm whether they
# are intentional before removing them.
program = """
def @main(%input0: Tensor[(1, 32, 224, 224), float32],
%v0_0_weight: Tensor[(32, 1, 3, 3), float32],
%v1_conv_0_0_weight: Tensor[(32, 1, 3, 3), float32],
%v1_conv_1_weight: Tensor[(16, 32, 1, 1), float32]) -> Tensor[(1, 32, 224, 224), float32] {
%0 = nn.conv2d(%input0, %v0_0_weight, strides=[1, 1], padding=[1, 1, 1, 1], groups=32, channels=32, kernel_size=[3, 3]);
%0
}
"""
# Forward pass: parse the Relay program, type-check it, and compile it for
# the CPU to confirm the plain forward graph builds.
mod = parse_module(program)
mod = relay.transform.InferType()(mod)
target = "llvm"
lib = relay.build(mod, target=target, params=None)
print("build [fwd] pass successful")

# Backward pass: differentiate @main with the first-order AD pass and compile
# the resulting gradient function the same way.
# Type inference is re-run so mod["main"] carries checked types for the
# gradient pass (presumably required by it; kept as a defensive step).
mod = relay.transform.InferType()(mod)
bwd_ir = relay.transform.gradient(mod["main"], mode="first_order")
bwd_mod = tvm.IRModule.from_expr(bwd_ir)
print(bwd_mod)
lib = relay.build(bwd_mod, target=target, params=None)
# This build compiles the gradient (backward) module, so report [bwd];
# the previous message was a copy of the forward one.
print("build [bwd] pass successful")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment