@icemelon
Last active December 29, 2022 04:09
Optimize the BERT model on CPUs
import time
import argparse
import numpy as np
import mxnet as mx
import gluonnlp as nlp
import tvm
from tvm import relay
import tvm.contrib.graph_runtime as runtime
import tvm.testing  # needed for tvm.testing.assert_allclose below
def timer(thunk, repeat=1, number=10, dryrun=3, min_repeat_ms=1000):
"""Helper function to time a function"""
for i in range(dryrun):
thunk()
ret = []
for _ in range(repeat):
while True:
beg = time.time()
for _ in range(number):
thunk()
end = time.time()
lat = (end - beg) * 1e3
if lat >= min_repeat_ms:
break
number = int(max(min_repeat_ms / (lat / number) + 1, number * 1.618))
ret.append(lat / number)
return ret
parser = argparse.ArgumentParser(description="Optimize BERT-base model from GluonNLP")
parser.add_argument("-b", "--batch", type=int, default=1,
help="Batch size (default: 1)")
parser.add_argument("-l", "--length", type=int, default=128,
help="Sequence length (default: 128)")
args = parser.parse_args()
batch = args.batch
seq_length = args.length
# Instantiate a BERT classifier using GluonNLP
model_name = 'bert_12_768_12'
dataset = 'book_corpus_wiki_en_uncased'
mx_ctx = mx.cpu()
bert, _ = nlp.model.get_model(
    name=model_name,
    ctx=mx_ctx,
    dataset_name=dataset,
    pretrained=False,
    use_pooler=True,
    use_decoder=False,
    use_classifier=False)
model = nlp.model.BERTClassifier(bert, dropout=0.1, num_classes=2)
model.initialize(ctx=mx_ctx)
model.hybridize(static_alloc=True)
# Prepare input data
dtype = "float32"
inputs = np.random.randint(0, 2000, size=(batch, seq_length)).astype(dtype)
token_types = np.random.uniform(size=(batch, seq_length)).astype(dtype)
valid_length = np.asarray([seq_length] * batch).astype(dtype)
# Convert to MXNet NDArray and run the MXNet model
inputs_nd = mx.nd.array(inputs, ctx=mx_ctx)
token_types_nd = mx.nd.array(token_types, ctx=mx_ctx)
valid_length_nd = mx.nd.array(valid_length, ctx=mx_ctx)
mx_out = model(inputs_nd, token_types_nd, valid_length_nd)
mx_out.wait_to_read()
# Benchmark the MXNet latency
res = timer(lambda: model(inputs_nd, token_types_nd, valid_length_nd).wait_to_read(),
            repeat=3,
            dryrun=5,
            min_repeat_ms=1000)
print(f"MXNet latency for batch {batch} and seq length {seq_length}: {np.mean(res):.2f} ms")
######################################
# Optimize the BERT model using TVM
######################################
# First, convert the MXNet model into TVM Relay format
shape_dict = {
    'data0': (batch, seq_length),
    'data1': (batch, seq_length),
    'data2': (batch,)
}
mod, params = relay.frontend.from_mxnet(model, shape_dict)
# Compile the imported model
target = "llvm -mcpu=skylake-avx512 -libs=cblas"
with relay.build_config(opt_level=3, required_pass=["FastMath"]):
    graph, lib, cparams = relay.build(mod, target, params=params)
# Create the executor and set the parameters and inputs
ctx = tvm.cpu()
rt = runtime.create(graph, lib, ctx)
rt.set_input(**cparams)
rt.set_input(data0=inputs, data1=token_types, data2=valid_length)
# Run the executor and validate the correctness
rt.run()
out = rt.get_output(0)
tvm.testing.assert_allclose(out.asnumpy(), mx_out.asnumpy(), rtol=1e-3, atol=1e-3)
# Benchmark the TVM latency
ftimer = rt.module.time_evaluator("run", ctx, repeat=3, min_repeat_ms=1000)
prof_res = np.array(ftimer().results) * 1000
print(f"TVM latency for batch {batch} and seq length {seq_length}: {np.mean(prof_res):.2f} ms")
import time
import argparse
import numpy as np
import mxnet as mx
import gluonnlp as nlp
import tvm
from tvm import relay
import tvm.contrib.graph_runtime as runtime
import tvm.testing  # needed for tvm.testing.assert_allclose below
def timer(thunk, repeat=1, number=10, dryrun=3, min_repeat_ms=1000):
"""Helper function to time a function"""
for i in range(dryrun):
thunk()
ret = []
for _ in range(repeat):
while True:
beg = time.time()
for _ in range(number):
thunk()
end = time.time()
lat = (end - beg) * 1e3
if lat >= min_repeat_ms:
break
number = int(max(min_repeat_ms / (lat / number) + 1, number * 1.618))
ret.append(lat / number)
return ret
parser = argparse.ArgumentParser(description="Optimize DistilBERT model from GluonNLP")
parser.add_argument("-b", "--batch", type=int, default=1,
help="Batch size (default: 1)")
parser.add_argument("-l", "--length", type=int, default=128,
help="Sequence length (default: 128)")
args = parser.parse_args()
batch = args.batch
seq_length = args.length
# Instantiate a DistilBERT classifier using GluonNLP (RoBERTaClassifier is used
# because DistilBERT, like RoBERTa, takes no token-type input)
model_name = 'distilbert_6_768_12'
dataset = 'distilbert_book_corpus_wiki_en_uncased'
mx_ctx = mx.cpu()
bert, _ = nlp.model.get_model(
    name=model_name,
    ctx=mx_ctx,
    dataset_name=dataset,
    pretrained=False,
    use_pooler=False,
    use_decoder=False,
    use_classifier=False)
model = nlp.model.RoBERTaClassifier(bert, dropout=0.1, num_classes=2)
model.initialize(ctx=mx_ctx)
model.hybridize(static_alloc=True)
# Prepare input data
dtype = "float32"
inputs = np.random.randint(0, 2000, size=(batch, seq_length)).astype(dtype)
valid_length = np.asarray([seq_length] * batch).astype(dtype)
# Convert to MXNet NDArray and run the MXNet model
inputs_nd = mx.nd.array(inputs, ctx=mx_ctx)
valid_length_nd = mx.nd.array(valid_length, ctx=mx_ctx)
mx_out = model(inputs_nd, valid_length_nd)
mx_out.wait_to_read()
# Benchmark the MXNet latency
res = timer(lambda: model(inputs_nd, valid_length_nd).wait_to_read(),
            repeat=3,
            dryrun=5,
            min_repeat_ms=1000)
print(f"MXNet latency for batch {batch} and seq length {seq_length}: {np.mean(res):.2f} ms")
######################################
# Optimize the DistilBERT model using TVM
######################################
# First, convert the MXNet model into TVM Relay format
shape_dict = {
    'data0': (batch, seq_length),
    'data1': (batch,)
}
mod, params = relay.frontend.from_mxnet(model, shape_dict)
# Compile the imported model
target = "llvm -mcpu=skylake-avx512 -libs=cblas"
with relay.build_config(opt_level=3, required_pass=["FastMath"]):
    graph, lib, cparams = relay.build(mod, target, params=params)
# Create the executor and set the parameters and inputs
ctx = tvm.cpu()
rt = runtime.create(graph, lib, ctx)
rt.set_input(**cparams)
rt.set_input(data0=inputs, data1=valid_length)
# Run the executor and validate the correctness
rt.run()
out = rt.get_output(0)
tvm.testing.assert_allclose(out.asnumpy(), mx_out.asnumpy(), rtol=1e-3, atol=1e-3)
# Benchmark the TVM latency
ftimer = rt.module.time_evaluator("run", ctx, repeat=3, min_repeat_ms=1000)
prof_res = np.array(ftimer().results) * 1000
print(f"TVM latency for batch {batch} and seq length {seq_length}: {np.mean(prof_res):.2f} ms")
@652994331

Hi, I followed your code, but I have a problem here:
MXNet latency for batch 1 and seq length 128: 112.44 ms
Traceback (most recent call last):
  File "test_bert.py", line 86, in <module>
    mod, params = relay.frontend.from_mxnet(model, shape_dict)
  File "/opt/cephfs1/asr/users/qizhou.huang/.local/lib/python3.6/site-packages/tvm-0.6.0-py3.6-linux-x86_64.egg/tvm/relay/frontend/mxnet.py", line 1427, in from_mxnet
    func = _from_mxnet_impl(sym, shape, dtype, mod)
  File "/opt/cephfs1/asr/users/qizhou.huang/.local/lib/python3.6/site-packages/tvm-0.6.0-py3.6-linux-x86_64.egg/tvm/relay/frontend/mxnet.py", line 1340, in _from_mxnet_impl
    'Operator {} is not supported in frontend MXNet.'.format(op_name))
tvm.error.OpNotImplemented: Operator _contrib_arange_like is not supported in frontend MXNet.

My TVM version is 0.6.1 and I just pip installed mxnet (1.6.0). Could you please help me out? Thank you.

@icemelon (Author) commented Jul 27, 2020 via email

@652994331

Thanks so much Haichen, that problem is solved with the latest version of TVM. However, the test still has a problem:
Cannot find config for target=llvm -keys=cpu -libs=cblas -mcpu=skylake-avx512, workload=('dense_cblas.x86', ('TENSOR', (1, 768), 'float32'), ('TENSOR', (2, 768), 'float32'), None, 'float32'). A fallback configuration is used, which may bring great performance regression.
[07:37:14] /opt/cephfs1/asr/users/qizhou.huang/qizhou/PycharmProjects/incubator-tvm/src/tir/transforms/arg_binder.cc:95: Trying to bind buffer to another one with lower alignment requirement required_alignment=128, provided_alignment=8
Then after a while it core dumps and exits, without any other error messages. Could you please help me with this?

@tiandiao123

> Thanks so much Haichen, that problem is solved with the latest version of TVM. However, the test still has a problem:
> Cannot find config for target=llvm -keys=cpu -libs=cblas -mcpu=skylake-avx512, workload=('dense_cblas.x86', ('TENSOR', (1, 768), 'float32'), ('TENSOR', (2, 768), 'float32'), None, 'float32'). A fallback configuration is used, which may bring great performance regression.
> [07:37:14] /opt/cephfs1/asr/users/qizhou.huang/qizhou/PycharmProjects/incubator-tvm/src/tir/transforms/arg_binder.cc:95: Trying to bind buffer to another one with lower alignment requirement required_alignment=128, provided_alignment=8
> Then after a while it core dumps and exits, without any other error messages. Could you please help me with this?

I think it is just a warning message, which means you didn't tune this op. But since you are using external libs (cblas), this warning can be ignored in this case.
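
For reference, here is a rough AutoTVM sketch of how the dense workloads could be tuned so that the fallback warning goes away. This is not part of the original gist: the log file name, trial count, and the 0.7-era autotvm API are assumptions, and you would drop -libs=cblas from the target so the tuned TVM schedules are used instead of the external BLAS library.

from tvm import autotvm

# Extract tunable tasks (dense, batch_matmul, ...) from the Relay module
tasks = autotvm.task.extract_from_program(mod["main"], target=target, params=params)

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(number=10, repeat=1, min_repeat_ms=1000))

for task in tasks:
    tuner = autotvm.tuner.XGBTuner(task)
    tuner.tune(n_trial=200,  # placeholder trial budget
               measure_option=measure_option,
               callbacks=[autotvm.callback.log_to_file("bert_tune.log")])

# Rebuild with the best schedules found during tuning
with autotvm.apply_history_best("bert_tune.log"):
    with relay.build_config(opt_level=3, required_pass=["FastMath"]):
        graph, lib, cparams = relay.build(mod, target, params=params)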
