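# AutoTVM tuning and benchmarking script for the topi.x86 bitserial conv2d operator.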
import logging
import sys
import numpy as np
import tvm
import tvm.topi.testing
from tvm import te, testing
from tvm.topi.utils import get_const_tuple
from tvm import autotvm, topi
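# Workload: 1x64x56x56 NCHW input, 64 3x3 filters, stride 1, pad 1, 1-bit activations and weights.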
batch = 1
in_size = 56
in_channel = 64
num_filter = 64
kernel = 3
stride = 1
padding = 1
activation_bits = 1
weight_bits = 1
unipolar = True
in_height = in_width = in_size
input_dtype = "uint32"
out_dtype = "int32"
def generate_quantized_np(shape, bits, out_dtype):
    # Uniform random integers representable with `bits` bits
    min_val = 0
    max_val = 1 << bits
    return np.random.randint(min_val, max_val, size=shape).astype(out_dtype)
def get_ref_data():
    # a_shape and w_shape are module-level globals taken from the placeholder shapes below
    a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_dtype)
    w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_dtype)
    return a_np, w_np
@autotvm.template("bitserial_conv2d")
def bit_serial_conv():
# A = te.placeholder((batch, in_height, in_width, in_channel), dtype=input_dtype, name="A")
# W = te.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_dtype, name="W")
# B = topi.x86.bitserial_conv2d_nhwc(
# A, W, stride, padding, activation_bits, weight_bits, input_dtype, out_dtype, unipolar
# )
# s = topi.x86.schedule_bitserial_conv2d_nhwc([B])
A = te.placeholder((batch, in_channel, in_height, in_width), dtype=input_dtype, name="A")
W = te.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_dtype, name="W")
B = topi.x86.bitserial_conv2d_nchw(
A, W, stride, padding, activation_bits, weight_bits, input_dtype, out_dtype, unipolar
)
s = topi.x86.schedule_bitserial_conv2d_nchw([B])
return s, [A, W, B]
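# Create the AutoTVM task for the template above and pick the CPU target.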
target = "llvm -mcpu=icelake-client"
# target = "llvm -mcpu=cascadelake"
task = autotvm.task.create("bitserial_conv2d", args=(), target=target)
print(task.config_space)
logging.getLogger("autotvm").setLevel(logging.DEBUG)
logging.getLogger("autotvm").addHandler(logging.StreamHandler(sys.stdout))
measure_option = autotvm.measure_option(builder="local", runner=autotvm.LocalRunner(number=5))
tuner = autotvm.tuner.RandomTuner(task)
# tuner = autotvm.tuner.XGBTuner(task, loss_type="rank")
log_file = "bit_serial.log"
tuner.tune(
    n_trial=100,
    measure_option=measure_option,
    callbacks=[autotvm.callback.log_to_file(log_file)],
)
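# Apply the best config found during tuning and build the kernel.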
with autotvm.apply_history_best(log_file):
    with tvm.target.Target(target):
        s, arg_bufs = bit_serial_conv()
        func = tvm.build(s, arg_bufs)
A, W, B = arg_bufs
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
a_np, w_np = get_ref_data()
ctx = tvm.cpu(0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func(a, w, b)
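# Time the compiled kernel: 100 repeats of a single run, report the mean in milliseconds.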
ftimer = func.time_evaluator(func.entry_name, ctx, number=1, repeat=100)
prof_res = np.array(ftimer(a, w, b).results) * 1000  # convert seconds to milliseconds
print(prof_res.mean())
# print(func.get_source())
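# Optional correctness check, a sketch following the reference used in TVM's bitserial
# conv2d unit tests (not part of the original script): with unipolar=True, 1-bit weights
# act as +1/-1, so compare against a plain conv2d on the remapped weights using
# tvm.topi.testing.conv2d_nchw_python as the numpy reference.
w_ref = np.where(w_np == 1, 1, -1).astype(out_dtype)
b_ref = tvm.topi.testing.conv2d_nchw_python(a_np.astype(out_dtype), w_ref, stride, padding)
np.testing.assert_allclose(b.asnumpy(), b_ref, rtol=1e-5)
print("Result matches the conv2d reference")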