Skip to content

Instantly share code, notes, and snippets.

@schegde
Last active February 11, 2022 20:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save schegde/d2b252c2544841c7fd3f77f0b6a694e7 to your computer and use it in GitHub Desktop.
Save schegde/d2b252c2544841c7fd3f77f0b6a694e7 to your computer and use it in GitHub Desktop.
Conv2d stride comparison with Polygraphy
# This script exports a single Conv2d layer from PyTorch to ONNX, then compares TensorRT (FP16)
# against ONNX Runtime on that model for NUM_RUNS runs at the given STRIDE, and prints the results.
import torch
import torch.nn as nn
import numpy as np
import time
from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx
from polygraphy.backend.trt import CreateConfig as CreateTrtConfig, EngineFromNetwork, NetworkFromOnnxPath, TrtRunner
from polygraphy.comparator import Comparator
from polygraphy.common import TensorMetadata
# Precision of the PyTorch tensors and module; the exported ONNX model inherits it.
dtype_t = torch.float16
# NOTE(review): dtype_pyt is not referenced anywhere in this script.
dtype_pyt = np.float16
# NOTE(review): the TensorRT runner used in the comparison loop needs a CUDA GPU,
# so the CPU fallback can only help the export step. Also confirm that fp16 Conv2d
# tracing works on CPU with the installed PyTorch version.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_RUNS = 5
STRIDE = 2
# Initialize random inputs and weights for one channel.
input_layer = torch.randn([1, 1, 9, 9], device=device, dtype=dtype_t)
weight_layer = torch.randn([1, 1, 3, 3], device=device, dtype=dtype_t)
ONNX_FILE = "./conv2d.onnx"
# Replicate the input and weights for 64 channels.
input_t = input_layer.repeat(1, 64, 1, 1)    # -> [1, 64, 9, 9]
weight_t = weight_layer.repeat(1, 64, 1, 1)  # -> [1, 64, 3, 3], same shape as conv.weight below
# 64 input channels -> 1 output channel, 3x3 kernel, no bias; padding=1 with the chosen STRIDE.
conv = nn.Conv2d(64, 1, 3, stride=STRIDE, bias=False, padding=1, device=device, dtype=dtype_t)
# conv.weight is a leaf tensor that requires grad, so the in-place copy must run
# with autograd disabled. The body MUST be indented under the `with` statement.
with torch.no_grad():
    conv.weight.copy_(weight_t)
conv.eval()
torch.onnx.export(conv, input_t, ONNX_FILE, opset_version=11, verbose=True)
# Polygraphy loaders for the exported model; the comparison loop hands them to
# the TensorRT and ONNX Runtime runners.
parse_network_from_onnx = NetworkFromOnnxPath(ONNX_FILE)
build_onnxrt_session = SessionFromOnnx(ONNX_FILE)
# Aggregate accuracy outcome across all runs. It is initialized once, outside the
# loop, so a failure in an earlier run is not overwritten by later runs.
success = True
for run in range(NUM_RUNS):
    print("************************ Run:" + str(run))
    # TensorRT builder config with FP16 enabled (the exported model is already fp16).
    create_trt_config = CreateTrtConfig(max_workspace_size=200000000, fp16=True)
    # A new engine loader is constructed every run; presumably this is so each
    # iteration builds and tests a fresh TensorRT engine. TODO confirm intent.
    build_engine = EngineFromNetwork(parse_network_from_onnx, config=create_trt_config)
    runners = [
        TrtRunner(build_engine),
        OnnxrtRunner(build_onnxrt_session),
    ]
    # Runner Execution
    results = Comparator.run(runners)
    # Accuracy Comparison: record this run's verdict and fold it into the total.
    run_passed = bool(Comparator.compare_accuracy(results))
    print("Run " + str(run) + " accuracy check: " + ("PASSED" if run_passed else "FAILED"))
    success &= run_passed
    # NOTE(review): the source gives no reason for this pause between runs. TODO confirm.
    time.sleep(3)
print("Overall accuracy over " + str(NUM_RUNS) + " run(s): " + ("PASSED" if success else "FAILED"))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment