Skip to content

Instantly share code, notes, and snippets.

View jaemin93's full-sized avatar
🎯
Focusing

Jaemin Jung jaemin93

🎯
Focusing
  • Seoul
View GitHub Profile
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
import torch
from classification_model import Net
from torch.onnx import OperatorExportTypes
# Instantiate the classifier and put it in inference mode before tracing.
# nn.Module.eval() returns the module itself, so it can be chained here.
model = Net().eval()

# Dummy tensor used only to trace the graph for export: shape (1, 1, 28, 28).
# NOTE(review): assumes the imported Net expects single-channel 28x28 input
# (the Net previews in this gist define conv1 = nn.Conv2d(1, 32, 3, 1)) — confirm.
x = torch.zeros(1, 1, 28, 28)
print(x.shape)

# Trace `model` with `x` and write the resulting ONNX graph to test.onnx,
# restricting export to standard ONNX operators and logging the graph.
torch.onnx.export(
    model,
    x,
    "test.onnx",
    verbose=True,
    operator_export_type=OperatorExportTypes.ONNX,
)
import torch
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import os
import numpy as np
from tqdm import tqdm
# Single TensorRT logger, passed to both trt.Builder and trt.OnnxParser below.
TRT_LOGGER = trt.Logger()
# Flags bitmask for builder.create_network(): the enum value is used as a bit
# position, so shift 1 by it to request an explicit-batch network definition.
# Use a normal int() call rather than the C-style "(int)(...)" cast spelling.
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
def get_engine(onnx_file_path, engine_file_path=""):
def build_engine():
with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
builder.max_workspace_size = 1 << 28
builder.max_batch_size = 1
print('Loading ONNX file from path {}...'.format(onnx_file_path))
with open(onnx_file_path, 'rb') as model:
print('Beginning ONNX file parsing')
if not parser.parse(model.read()):
print('ERROR: Failed to parse the ONNX file')
def get_engine(onnx_file_path, engine_file_path=""):
def build_engine():
with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
builder.max_workspace_size = 1 << 28
builder.max_batch_size = 1
print('Loading ONNX file from path {}...'.format(onnx_file_path))
with open(onnx_file_path, 'rb') as model:
print('Beginning ONNX file parsing')
if not parser.parse(model.read()):
print('ERROR: Failed to parse the ONNX file')
def get_engine(onnx_file_path, engine_file_path=""):
def build_engine():
with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
builder.max_workspace_size = 1 << 28
builder.max_batch_size = 1
print('Loading ONNX file from path {}...'.format(onnx_file_path))
with open(onnx_file_path, 'rb') as model:
print('Beginning ONNX file parsing')
if not parser.parse(model.read()):
print('ERROR: Failed to parse the ONNX file')