-
-
Save Wheest/d19fda0afbdab7d55948c4a067947d85 to your computer and use it in GitHub Desktop.
Example which tries to load StyleGAN2 into TVM
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import pickle | |
import torch | |
from tvm import relay | |
from tvm.relay.frontend.pytorch import PyTorchOpConverter | |
# To load the model, you need to download both the model file and the code that defines it:
# 1. Download the model file:
#    wget https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
# 2. Download the code:
#    git clone https://github.com/dvschultz/stylegan2-ada-pytorch
# 3. Add the project to the PYTHONPATH (append with ">>" so the existing ~/.bashrc is not overwritten):
#    $ echo "export PYTHONPATH=stylegan2-ada-pytorch:\${PYTHONPATH}" >> ~/.bashrc
#    $ source ~/.bashrc
def get_stylegan2_full(path):
    """Load the ``G_ema`` entry of a pickled checkpoint and move it to CUDA.

    Args:
        path: Filesystem path of the pickle file (e.g. ``metfaces.pkl``).

    Returns:
        The object stored under the ``"G_ema"`` key, after ``.cuda()`` has
        been called on it.

    Note:
        ``pickle.load`` can execute arbitrary code embedded in the file;
        only load checkpoints obtained from a trusted source.
    """
    with open(path, "rb") as checkpoint_file:
        checkpoint = pickle.load(checkpoint_file)
    # torch.nn.Module, needs to be cuda
    generator = checkpoint["G_ema"]
    return generator.cuda()
def randn(inputs, input_types):
    """Custom converter for ``aten::randn`` that emits a Relay constant.

    The requested shape is read from ``inputs[0]``. Entries that are
    ``relay.Constant`` are unwrapped through their ``data`` array; every
    other entry is converted with ``int()``.

    The random values are drawn once, when this converter runs, and are
    stored in the returned constant.

    Args:
        inputs: Converter inputs; ``inputs[0]`` is the iterable of dimensions.
        input_types: Unused.

    Returns:
        A Relay constant holding the ``torch.randn`` sample.
    """
    dims = []
    for dim in inputs[0]:
        if isinstance(dim, relay.Constant):
            dims.append(int(dim.data.asnumpy()))
        else:
            dims.append(int(dim))
    sample = torch.randn(size=tuple(dims))
    return relay.expr.const(sample.numpy())
def square(inputs, input_types):
    """Custom converter for ``aten::square``.

    The original implementation called ``torch.square`` on the converter
    inputs directly. That only works when the inputs are ``torch.Tensor``
    objects and raises ``TypeError`` for Relay expressions. This version
    keeps the tensor behaviour and adds a path for everything else.

    Args:
        inputs: Converter inputs; ``inputs[0]`` is the value to square.
        input_types: Unused.

    Returns:
        A Relay constant when ``inputs[0]`` is a ``torch.Tensor``;
        otherwise the result of ``inputs[0] * inputs[0]``.
    """
    data = inputs[0]
    if isinstance(data, torch.Tensor):
        # Unchanged legacy path: fold the tensor into a constant.
        return relay.expr.const(torch.square(*inputs).numpy())
    # Relay expressions overload ``*``, so this emits an element-wise
    # multiply in the graph instead of failing inside torch.square.
    return data * data
def none(inputs, input_types):
    """Placeholder converter that ignores its arguments and yields ``None``.

    Args:
        inputs: Unused.
        input_types: Unused.

    Returns:
        Always ``None``.
    """
    return
def pytorch_exporter(model):
    """Trace ``model`` with TorchScript and import the trace into TVM Relay.

    Two random float32 CUDA tensors of shape ``(1, 512)`` and ``(1,)`` are
    used as example inputs (presumably the latent vector and the class
    label of the generator; confirm against the model definition).

    Args:
        model: A CUDA ``torch.nn.Module`` callable with the two example
            inputs described above.

    Returns:
        The ``(mod, params)`` pair produced by
        ``relay.frontend.from_pytorch``. Earlier versions discarded this
        result and returned ``None``.
    """
    example_inputs = (
        torch.randn((1, 512), dtype=torch.float32).cuda(),
        torch.randn((1,), dtype=torch.float32).cuda(),
    )

    model.eval()
    # Eager forward pass before tracing; its output is not used here.
    model(*example_inputs)

    scripted_model = torch.jit.trace(model, example_inputs).eval()
    shape_list = [("input0", [1, 512]), ("input1", [1])]

    # Converters for ops in the traced graph that get special handling.
    # NOTE(review): "aten::square" is stubbed with `none` rather than the
    # `square` converter defined in this file; confirm this is intended.
    custom_convert_map = {
        "aten::randn": randn,
        "profiler::_record_function_enter": none,
        "profiler::_record_function_exit": none,
        "prim::PythonOp": none,
        "aten::square": none,
    }
    mod, params = relay.frontend.from_pytorch(
        scripted_model,
        shape_list,
        custom_convert_map=custom_convert_map,
    )
    print("Exported to TVM!")
    return mod, params
if __name__ == "__main__":
    # The relative path means the checkpoint is looked up in the current
    # working directory.
    generator = get_stylegan2_full("metfaces.pkl")
    pytorch_exporter(generator)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment