Skip to content

Instantly share code, notes, and snippets.

@Wheest
Created December 15, 2021 16:35
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Wheest/d19fda0afbdab7d55948c4a067947d85 to your computer and use it in GitHub Desktop.
Save Wheest/d19fda0afbdab7d55948c4a067947d85 to your computer and use it in GitHub Desktop.
Example which tries to load StyleGAN2 into TVM
#!/usr/bin/env python
import pickle
import torch
from tvm import relay
from tvm.relay.frontend.pytorch import PyTorchOpConverter
# To load the model you need two things: the pretrained weights and the StyleGAN2-ADA source code (it must be importable for the pickle to load)
# 1. download the model file:
# wget https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
# 2. download the code
# git clone https://github.com/dvschultz/stylegan2-ada-pytorch
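# 3. add the project to the PYTHONPATH, running this from the directory that contains the clone
#    (use `>>` to append -- a single `>` would overwrite ~/.bashrc; $(pwd) makes the path absolute)
# $ echo "export PYTHONPATH=$(pwd)/stylegan2-ada-pytorch:\${PYTHONPATH}" >> ~/.bashrc
# $ source ~/.bashrc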
def get_stylegan2_full(path, device="cuda"):
    """Load the ``"G_ema"`` generator module from a StyleGAN2-ADA pickle.

    The pickle is expected to hold a dict with a ``"G_ema"`` entry that is a
    ``torch.nn.Module``. Unpickling needs the stylegan2-ada-pytorch sources
    to be importable (e.g. on ``PYTHONPATH``).

    Args:
        path: Filesystem path of the ``.pkl`` file (e.g. ``metfaces.pkl``).
        device: Device the module is moved to. Defaults to ``"cuda"``, which
            behaves like ``Module.cuda()``; pass ``"cpu"`` to load without a
            GPU.

    Returns:
        The ``"G_ema"`` module, moved to *device*.
    """
    with open(path, "rb") as f:
        # SECURITY: pickle.load can execute arbitrary code -- only load
        # pickles from a trusted source.
        generator = pickle.load(f)["G_ema"]
    # Module.to moves parameters/buffers in place and returns the module.
    return generator.to(device)
def randn(inputs, input_types):
    """Custom ``aten::randn`` converter that bakes a random tensor into a constant.

    Only ``inputs[0]`` (the requested size) is read. Each size entry is
    either a ``relay.Constant``, whose scalar payload is extracted, or a value
    that ``int()`` accepts. ``input_types`` and any further inputs are
    ignored, so the sample uses torch's default dtype.

    The tensor is sampled once, at conversion time, and embedded with
    ``relay.expr.const``; the resulting Relay module therefore reuses the
    same values on every execution rather than drawing fresh noise.
    """
    shape = []
    for dim in inputs[0]:
        if isinstance(dim, relay.Constant):
            shape.append(int(dim.data.asnumpy()))
        else:
            shape.append(int(dim))
    sample = torch.randn(size=tuple(shape))
    return relay.expr.const(sample.numpy())
def square(inputs, input_types):
    """Custom ``aten::square`` converter that emits ``x * x`` in Relay.

    A TVM PyTorch-frontend converter receives Relay expressions in *inputs*,
    not torch tensors, so ``torch.square`` cannot be applied to them. The
    op must be built from Relay operators instead. ``input_types`` is unused.

    Returns:
        A Relay expression computing the element-wise square of ``inputs[0]``.
    """
    x = inputs[0]
    return relay.multiply(x, x)
def none(inputs, input_types):
    """Custom converter that produces no Relay expression for an op.

    Both arguments are ignored and the result is always ``None``; it is used
    to discard ops the frontend should not translate.
    """
    return
def pytorch_exporter(model):
    """Trace *model* with TorchScript and import it into TVM Relay.

    The model is put in eval mode and traced with ``torch.jit.trace`` on two
    random float32 CUDA tensors of shapes ``(1, 512)`` and ``(1,)``. The
    trace is then converted with ``relay.frontend.from_pytorch``, using the
    custom converters defined in this script for ops the frontend cannot
    handle.

    Args:
        model: A ``torch.nn.Module`` that accepts two positional tensors of
            the shapes above. It must already live on the CUDA device, since
            the example inputs are created with ``.cuda()``.

    Returns:
        The ``(mod, params)`` pair produced by
        ``relay.frontend.from_pytorch``, so callers can build or compile it.
    """
    example_inputs = (
        torch.randn((1, 512), dtype=torch.float32).cuda(),
        torch.randn((1,), dtype=torch.float32).cuda(),
    )
    model.eval()
    # torch.jit.trace already executes the model on example_inputs, so no
    # separate forward pass is needed here.
    scripted_model = torch.jit.trace(model, example_inputs).eval()
    # Names and shapes must match example_inputs.
    shape_list = [("input0", [1, 512]), ("input1", [1])]
    custom_convert_map = {
        "aten::randn": randn,
        # Profiler markers and opaque Python ops are discarded by mapping
        # them to a converter that returns None.
        "profiler::_record_function_enter": none,
        "profiler::_record_function_exit": none,
        "prim::PythonOp": none,
        # NOTE(review): this also discards aten::square's result (``none``
        # returns None); any op consuming the squared tensor presumably
        # receives None -- verify, and consider a converter that emits
        # relay.multiply(x, x) instead.
        "aten::square": none,
    }
    mod, params = relay.frontend.from_pytorch(
        scripted_model,
        shape_list,
        custom_convert_map,
    )
    print("Exported to TVM!")
    return mod, params
if __name__ == "__main__":
    # Expects metfaces.pkl in the current working directory; load the
    # generator and attempt the TorchScript -> Relay conversion.
    pytorch_exporter(get_stylegan2_full("metfaces.pkl"))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment