@qingcd · Created March 10, 2022 03:12
The results of multiple runs of the same TVM model are different.
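
Repro: build a small 3D DenseNet in PyTorch, trace it, compile it with TVM for
CUDA at opt_level=3, then run the compiled module twice on the same input and
print the max absolute difference between the two outputs; the two runs
unexpectedly disagree. A workaround (changing the CUDA target format) is
sketched at the bottom of the script.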
import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

import tvm
from tvm import relay
from tvm.contrib import graph_executor
from tvm.contrib.nvcc import compile_cuda


class BatchActivateConvLayer(nn.Module):
    """BN -> 1x1 bottleneck conv -> BN -> 3x3 conv, with optional 3D dropout."""

    def __init__(
        self, channel_in, growth_rate, bottleneck_size_basic_factor, drop_ratio=0.8
    ):
        super(BatchActivateConvLayer, self).__init__()
        self.drop_ratio = drop_ratio
        self.growth_rate = growth_rate
        self.bottleneck_channel_out = bottleneck_size_basic_factor * growth_rate
        self.mode_bn = torch.nn.BatchNorm3d(channel_in)
        self.mode_conv = nn.Conv3d(
            channel_in, self.bottleneck_channel_out, kernel_size=1, stride=1, bias=False
        )
        self.bn = torch.nn.BatchNorm3d(self.bottleneck_channel_out)
        self.conv = nn.Conv3d(
            self.bottleneck_channel_out,
            growth_rate,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.drop_out = nn.Dropout3d(p=self.drop_ratio)

    def forward(self, x):
        current = self.mode_bn(x)
        current = self.mode_conv(current)
        current = self.bn(current)
        current = self.conv(current)
        if self.drop_ratio > 0:
            current = self.drop_out(current)
        return current


class DenseBlock(nn.Module):
    """Stack of BatchActivateConvLayer units with dense (concatenated) connections."""

    def __init__(
        self,
        current_block_layers_number,
        channel_in,
        growth_rate,
        bottleneck_size_basic_factor,
        drop_ratio=0.8,
    ):
        super(DenseBlock, self).__init__()
        self.channel_in = channel_in
        self.growth_rate = growth_rate
        self.bottleneck_size_basic_factor = bottleneck_size_basic_factor
        self.current_channel_in = self.channel_in
        self.current_block_drop_ratio = drop_ratio
        self.current_block_layer_number = current_block_layers_number
        for i in range(self.current_block_layer_number):
            current_block_layer = BatchActivateConvLayer(
                self.current_channel_in,
                self.growth_rate,
                self.bottleneck_size_basic_factor,
                self.current_block_drop_ratio,
            )
            setattr(self, "block_layer_" + str(i), current_block_layer)
            self.current_channel_in += self.growth_rate

    def get_current_block_channel_out(self):
        return self.current_channel_in

    def forward(self, x):
        current = x
        for i in range(self.current_block_layer_number):
            current_clone = current.clone()
            tmp = getattr(self, "block_layer_" + str(i))(current_clone)
            current = torch.cat((current, tmp), 1)
        return current


class DenseNet(nn.Module):
    """Small 3D DenseNet: an initial conv followed by a chain of dense blocks."""

    def __init__(
        self,
        growth_rate=24,
        block_config=(2, 2),
        compression=0.5,
        num_init_features=24,
        bottleneck_size_basic_factor=2,
        drop_rate=0,
        num_classes=2,
        small_inputs=True,
        rnn_units=512,
    ):
        super(DenseNet, self).__init__()
        self.features = nn.Conv3d(
            1, num_init_features, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.init_feature_channel_number = num_init_features
        self.growth_rate = growth_rate
        self.compression = compression
        self.number_class = num_classes
        self.block_config = block_config
        self.rnn_units = rnn_units
        self.drop_ratio = drop_rate
        num_features = num_init_features
        self.dense_transition_output_list = []
        for i, num_layers in enumerate(self.block_config):
            block = DenseBlock(
                num_layers,
                num_features,
                self.growth_rate,
                bottleneck_size_basic_factor,
                drop_rate,
            )
            setattr(self, "block_" + str(i), block)
            num_features = num_features + num_layers * growth_rate
            self.dense_transition_output_list.append(num_features)
        # He-style init for conv weights; BN weights/biases set to 1/0.
        for name, param in self.named_parameters():
            if "conv" in name and "weight" in name:
                n = param.size(0) * param.size(2) * param.size(3) * param.size(4)
                param.data.normal_().mul_(math.sqrt(2.0 / n))
            elif "norm" in name and "weight" in name:
                param.data.fill_(1)
            elif "norm" in name and "bias" in name:
                param.data.fill_(0)

    def forward(self, x):
        # Only the first input channel is fed to the network.
        features = self.features(x[:, :1])
        for i in range(len(self.block_config)):
            features = getattr(self, "block_" + str(i))(features)
        return features
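
# Shape note (illustrative, not part of the original gist): with the defaults
# above, each of the two dense blocks stacks 2 layers that each add 24
# channels, so the output has 24 + 2 * 2 * 24 = 120 channels while the
# spatial dims are preserved, e.g.
#   DenseNet()(torch.randn(1, 2, 8, 8, 8)).shape == torch.Size([1, 120, 8, 8, 8])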


def run_tvm_module(module, inpt):
    """Set the input, run the compiled graph once, and fetch the first output."""
    module.set_input(0, inpt)
    module.run()
    tvm.cuda().sync()
    res = module.get_output(0).numpy()
    return res


if __name__ == "__main__":
    model = DenseNet()
    model.eval()
    model_jit = torch.jit.trace(model, example_inputs=torch.randn((4, 2, 64, 64, 64)))
    print("finish gen trace model")

    relay_model, params = relay.frontend.from_pytorch(
        model_jit, [("input_0", (4, 2, 64, 64, 64))], default_dtype="float32"
    )
    target = tvm.target.cuda()
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(relay_model, target=target, params=params)
    lib.export_library("./dense.so")
    del lib
    print("finish compile tvm model")

    # Cast to float32 so the input dtype matches what the module was compiled for.
    inpt = np.random.random((4, 2, 64, 64, 64)).astype(np.float32)
    lib = tvm.runtime.load_module("./dense.so")
    module = graph_executor.GraphModule(lib["default"](tvm.cuda()))
    res1 = run_tvm_module(module, inpt)
    res2 = run_tvm_module(module, inpt)
    diff = res1 - res2
    print("max abs diff is:", np.max(np.abs(diff)))


# Workaround: change target_format from "cubin" to "ptx" in
# python/tvm/contrib/nvcc.py, or register this callback to compile the
# generated CUDA source to fatbin instead.
@tvm._ffi.register_func
def tvm_callback_cuda_compile(code):
    """Use nvcc to generate fatbin code for better optimization."""
    fatbin = compile_cuda(code, target_format="fatbin")
    return fatbin
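
# Note: this callback only affects compilation, so it must be registered
# before relay.build is called for the override to take effect.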