Skip to content

Instantly share code, notes, and snippets.

@Wheest
Created August 11, 2022 14:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Save Wheest/a05837c1ae257d177ef5431ef5ac4768 to your computer and use it in GitHub Desktop.
TVM sparse DNN example with auto-tuning
#!/usr/bin/env python
import tvm
import argparse
from tvm import te, relay, auto_scheduler
from tvm.contrib.download import download_testdata
from tvm.contrib import graph_executor as graph_runtime
from tvm.contrib import graph_executor
from tvm.relay import data_dep_optimization as ddo
import logging
import os
import timeit
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torchvision import transforms
import torch.nn.utils.prune as prune
from typing import Optional
# Global run configuration for the whole script.
batch_size = 1  # single-image inference
# x86-64 CPU target with AVX2; compiled kernels are tuned for this triple.
target = "llvm -mtriple=x86_64-linux-gnu -mcpu=core-avx2"
dev = tvm.device(target)  # runtime device matching the compile target
np.random.seed(42)  # deterministic NumPy randomness across runs
INPUT_SIZE = 128  # spatial height/width of the model input
class WeeNet(nn.Module):
    """Tiny one-layer CNN: a bias-free 3->32 channel 3x3 conv followed by ReLU."""

    def __init__(self):
        super().__init__()
        # No padding, so each spatial dimension shrinks by 2.
        self.layer1 = nn.Conv2d(
            3,
            32,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(0, 0),
            bias=False,
        )

    def forward(self, x):
        return F.relu(self.layer1(x))
def sparsify_model(mod, params):
    """Rewrite the Relay module so eligible conv2d weights use BSR sparse kernels.

    Weights whose density is below the threshold are converted to block-sparse
    form with a (1, 1) block. Returns the new IRModule and converted params.
    """
    block_rows = 1
    simplified, params = ddo.simplify_fc_transpose.convert(mod["main"], params)
    sparse_expr, params = ddo.bsr_conv2d.convert(
        simplified,
        params,
        blocksize=(block_rows, 1),
        sparsity_threshold=0.05,
        layout="NCHW",
    )
    return tvm.IRModule.from_expr(sparse_expr), params
def get_model():
    """Build, prune, trace and import WeeNet into Relay.

    Returns (mod, params, x, y): the sparsified IRModule and params plus the
    example input and the PyTorch reference output as NumPy arrays.
    """
    net = WeeNet()
    input_shape = (batch_size, 3, INPUT_SIZE, INPUT_SIZE)
    example = torch.randint(100, input_shape, dtype=torch.float32)

    # 70% unstructured random pruning, then bake the mask into the weights
    # so the traced graph carries genuinely sparse tensors.
    for layer in (net.layer1,):
        prune.random_unstructured(layer, name="weight", amount=0.7)
        prune.remove(layer, "weight")

    traced = torch.jit.trace(net, example).eval()
    reference = traced(example)

    mod, params = relay.frontend.from_pytorch(traced, [("data", input_shape)])
    mod, params = sparsify_model(mod, params)
    print(mod)

    return mod, params, example.numpy(), reference.detach().numpy()
def auto_schedule(trials=200):
    """Auto-tune the sparse model with Ansor and benchmark the tuned build.

    Extracts tuning tasks from the model, keeps only the conv2d tasks, tunes
    them with sparse sketch rules, then compiles with the best records and
    prints benchmark numbers.

    Parameters
    ----------
    trials : int
        Total number of measurement trials across all tasks; increase
        (e.g. to 20000) for best performance.
    """
    mod, params, x, _y_target = get_model()
    log_file = "/tmp/log_file.json"
    if os.path.exists(log_file):
        os.remove(log_file)

    tasks, task_weights = auto_scheduler.extract_tasks(mod, params, target)
    if len(tasks) == 0:
        print("Hey we do not have any tasks...")

    # Keep only conv2d tasks. Filter by index rather than list.remove():
    # weights frequently repeat (often all 1), and remove(w) would delete the
    # first equal value instead of the one paired with the dropped task,
    # silently misaligning tasks and weights.
    kept_tasks, kept_weights = [], []
    for idx, task in enumerate(tasks):
        print(
            "========== Task %d (workload key: %s) =========="
            % (idx, task.workload_key)
        )
        print(task.desc)
        print(task.compute_dag)
        if "conv2d" in task.desc:
            kept_tasks.append(task)
            kept_weights.append(task_weights[idx])
    tasks, task_weights = kept_tasks, kept_weights

    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=trials,
        runner=auto_scheduler.LocalRunner(repeat=10, enable_cpu_cache_flush=True),
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )

    # Sparse workloads need custom sketch rules so Ansor can generate
    # schedules for the BSR kernels.
    from tvm.topi.sparse.utils import sparse_sketch_rules

    search_policy = [
        auto_scheduler.SketchPolicy(
            task,
            program_cost_model=auto_scheduler.XGBModel(),
            init_search_callbacks=sparse_sketch_rules(),
        )
        for task in tasks
    ]
    tuner.tune(tune_option, search_policy=search_policy)

    # Show the best lowered schedule for the first remaining task.
    sch, args = tasks[0].apply_best(log_file)
    print("Lowered TIR:")
    print(tvm.lower(sch, args, simple_mode=True))

    # Rebuild the whole model applying the tuned records.
    with auto_scheduler.ApplyHistoryBest(log_file):
        with tvm.transform.PassContext(
            opt_level=3, config={"relay.backend.use_auto_scheduler": True}
        ):
            lib = relay.build(mod, target=target, params=params)

    module = graph_executor.GraphModule(lib["default"](dev))
    data_tvm = tvm.nd.array(x.astype("float32"))
    module.set_input("data", data_tvm)
    print("Ansor:")
    print(module.benchmark(dev, repeat=3, min_repeat_ms=500))
    print()
def regular():
    """Compile the model without tuning and print untuned benchmark numbers."""
    mod, params, x, _y_target = get_model()

    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)

    runtime_mod = graph_executor.GraphModule(lib["default"](dev))
    runtime_mod.set_input("data", tvm.nd.array(x.astype("float32")))
    runtime_mod.run()
    _ = runtime_mod.get_output(0).numpy()  # fetch output, as the original did

    print("Untuned:")
    print(runtime_mod.benchmark(dev, repeat=3, min_repeat_ms=500))
    print()
def main(args):
    """Dispatch to the requested run mode: untuned build or Ansor tuning."""
    if args.mode != "regular":
        auto_schedule(100)
    else:
        regular()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a model")
    # nargs="?" makes the positional optional; without it argparse ignores
    # `default` for positionals and the argument is required, so the
    # original default="regular" was dead code.
    parser.add_argument(
        "mode",
        nargs="?",
        default="regular",
        choices=["regular", "ansor"],
    )
    args = parser.parse_args()
    main(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment