# Last active: February 11, 2022 20:10
# Save schegde/d2b252c2544841c7fd3f77f0b6a694e7 to your computer and use it in GitHub Desktop.
# Conv2d stride comparison with Polygraphy
# (GitHub notice) This file contains bidirectional Unicode text that may be interpreted or
# compiled differently than what appears below. To review, open the file in an editor that
# reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
"""Compare a strided FP16 Conv2d between TensorRT and ONNX Runtime via Polygraphy.

A single-output-channel ``nn.Conv2d`` is built in PyTorch and exported to ONNX.
The exported model is then run NUM_RUNS times through a TensorRT FP16 engine and
through ONNX Runtime, and the outputs are compared with Polygraphy's Comparator.
Each run's verdict and an overall summary are printed. The process exits with
status 1 if any run fails the accuracy comparison.
"""
import sys
import time

import numpy as np
import torch
import torch.nn as nn
from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx
from polygraphy.backend.trt import CreateConfig as CreateTrtConfig, EngineFromNetwork, NetworkFromOnnxPath, TrtRunner
from polygraphy.comparator import Comparator
from polygraphy.common import TensorMetadata

dtype_t = torch.float16
# NOTE(review): TensorRT needs a CUDA GPU regardless of this choice. The CPU
# fallback only affects the PyTorch side, and FP16 convolution on CPU may not be
# supported by every PyTorch version.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_RUNS = 5
STRIDE = 2
ONNX_FILE = "./conv2d.onnx"
# Explicit name for the ONNX graph input, so the comparison below can feed it.
INPUT_NAME = "input"

# Initialize random inputs and weights for one channel.
input_layer = torch.randn([1, 1, 9, 9], device=device, dtype=dtype_t)
weight_layer = torch.randn([1, 1, 3, 3], device=device, dtype=dtype_t)
# Replicate the input and weights for 64 channels:
# input -> (1, 64, 9, 9); weight -> (1, 64, 3, 3), i.e. (out_ch, in_ch, kH, kW).
input_t = input_layer.repeat(1, 64, 1, 1)
weight_t = weight_layer.repeat(1, 64, 1, 1)

conv = nn.Conv2d(64, 1, 3, stride=STRIDE, bias=False, padding=1, device=device, dtype=dtype_t)
with torch.no_grad():
    conv.weight.copy_(weight_t)
conv.eval()
torch.onnx.export(conv, input_t, ONNX_FILE, opset_version=11, verbose=True, input_names=[INPUT_NAME])

parse_network_from_onnx = NetworkFromOnnxPath(ONNX_FILE)
build_onnxrt_session = SessionFromOnnx(ONNX_FILE)
# Use FP16 only for comparison. This config loader is identical on every run, so
# it is built once here.
create_trt_config = CreateTrtConfig(max_workspace_size=200000000, fp16=True)
# Feed the replicated input built above to both runners. Without an explicit
# data_loader, Polygraphy would generate its own input data, and input_t would
# only be used to trace the export.
data_loader = [{INPUT_NAME: input_t.cpu().numpy()}]

failed_runs = []
for run in range(NUM_RUNS):
    print("************************ Run:" + str(run))
    # A new engine loader per run. TensorRT may therefore build a different
    # engine (e.g. different tactics) on each run.
    build_engine = EngineFromNetwork(parse_network_from_onnx, config=create_trt_config)
    runners = [
        TrtRunner(build_engine),
        OnnxrtRunner(build_onnxrt_session),
    ]
    # Runner Execution
    results = Comparator.run(runners, data_loader=data_loader)
    # Accuracy Comparison
    run_passed = bool(Comparator.compare_accuracy(results))
    if not run_passed:
        failed_runs.append(run)
    print(f"Run {run}: {'PASSED' if run_passed else 'FAILED'}")
    # Pause between runs (the original delay); not needed after the last run.
    if run < NUM_RUNS - 1:
        time.sleep(3)

success = not failed_runs
summary = f"Stride {STRIDE}: {NUM_RUNS - len(failed_runs)}/{NUM_RUNS} runs passed"
if failed_runs:
    summary += f" (failed runs: {failed_runs})"
print(summary)
sys.exit(0 if success else 1)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.