make onnx for rangenet++
#!/usr/bin/env python3
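"""Inference script for RangeNet++-style range-image semantic segmentation.

Runs a pretrained network over the SemanticKITTI train/valid/test splits and
saves per-point predictions (or softmax probabilities with --store_probs).

Example invocation (a sketch; the dataset and model paths are placeholders):
    ./infer.py -d /data/semantic_kitti -p /models/pretrained -l /tmp/infer_log
"""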
import argparse
import os
import time
from parser import SemanticKittiLoader
import numpy as np
import torch
import auxiliary.tools as tools
from knn import nearest_1d
# fastest possible
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
def infer(loader,
          model,
          sequences_dir,
          map_fn,
          postp=None,
          probs=False,
          half=False):
    # empty the cache to infer in high res
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if half:
        print("Inferring these sequences in HALF precision")

    with torch.no_grad():
        for scan, mask, labels, seq, name, p_x, p_y, p_rang, u_rang, p_xyz, u_xyz, p_signal, u_signal in loader:
            # print eval scan
            seq = seq[0]
            name = name[0]
            print("seq", seq, "scan", name)

            if half:
                scan = scan.half()  # this one is normalized for cnn
                p_rang = p_rang.half()
                u_rang = u_rang.half()
                p_xyz = p_xyz.half()
                u_xyz = u_xyz.half()
                p_signal = p_signal.half()
                u_signal = u_signal.half()

            # convert scan to cuda
            if torch.cuda.is_available():
                scan = scan.cuda()  # this one is normalized for cnn
                p_x = p_x.cuda()
                p_y = p_y.cuda()
                p_rang = p_rang.cuda()
                u_rang = u_rang.cuda()
                p_xyz = p_xyz.cuda()
                u_xyz = u_xyz.cuda()
                p_signal = p_signal.cuda()
                u_signal = u_signal.cuda()

            # size of scan
            B, C, H, W = scan.shape

            # compute output
            start = time.time()
            output = model(scan)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            print("time infer: ", time.time() - start)

            # get classes for NN search (don't ask)
            nclasses = output.shape[1]

            # if I want probabilities, save in matrix of Npoints x Nclasses
            if probs:
                # get points
                pr = output[0][:, p_y, p_x].squeeze().t_()

                # get probabilities and save in file
                pr_np = pr.cpu().numpy()
                pr_np = pr_np.reshape((-1, nclasses)).astype(np.float32)

                # save pr
                path = os.path.join(sequences_dir, seq, "predictions", name)
                pr_np.tofile(path)
            else:
                # get argmax for range image representation
                p_argmax = output.argmax(dim=1)
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                print("time infer + argmax: ", time.time() - start)

                # get point's labels
                u_argmax = p_argmax[0][p_y, p_x]
                # according to the post-processing directive, run the requested
                # type of post-processing algorithm
                if postp:
                    u_argmax = nearest_1d(
                        postp,
                        p_rang=p_rang,
                        u_rang=u_rang,
                        p_xyz=p_xyz,
                        u_xyz=u_xyz,
                        p_signal=p_signal,
                        u_signal=u_signal,
                        p_argmax=p_argmax,
                        p_x=p_x,
                        p_y=p_y,
                        nclasses=nclasses)

                # smooth label with knn (if needed)
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                print("time infer + argmax + nearest: ", time.time() - start)

                # get the first scan in batch and project scan
                pred_np = u_argmax.cpu().numpy()
                pred_np = pred_np.reshape((-1)).astype(np.int32)

                # map to original label
                pred_np = map_fn(pred_np)

                # save scan
                path = os.path.join(sequences_dir, seq, "predictions", name)
                pred_np.tofile(path)
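
# Note on outputs (added for clarity): infer() writes one raw binary file per
# scan to <sequences_dir>/<seq>/predictions/<name>. With probs=True the file
# holds float32 softmax probabilities of shape (Npoints, nclasses); otherwise
# it holds one int32 label per point, remapped to the original label ids.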

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '--dataset',
        '-d',
        type=str,
        required=True,
        help='Dataset to infer. No default')
    parser.add_argument(
        '--logdir',
        '-l',
        type=str,
        required=False,
        default="/tmp/infer_log",
        help='Log directory. Defaults to %(default)s')
    parser.add_argument(
        '--pretrained',
        '-p',
        type=str,
        required=True,
        default=None,
        help='Pretrained weights directory. Defaults to %(default)s')
    parser.add_argument(
        '--store_probs',
        dest='store_probs',
        default=False,
        action='store_true',
        help='Store softmax probabilities instead of labels. Defaults to %(default)s')
    parser.add_argument(
        '--test_only',
        dest='test_only',
        default=False,
        action='store_true',
        help='Only infer the test set. Defaults to %(default)s')
    parser.add_argument(
        '--half',
        dest='half',
        default=False,
        action='store_true',
        help='Infer in 16-bit (half) precision. Defaults to %(default)s')
    parser.add_argument(
        '--postp',
        '-pp',
        default=None,
        required=False,
        nargs='+',
        help='Post-processing algorithm:\n'
             'type[string] : type of post-processing algorithm\n'
             'knn[int]     : number of neighbors for the nearest neighbor search in the range image\n'
             'search[int]  : kernel size for the neighbor search in the range image\n'
             'sigma[float] : sigma for the nn search weight\n'
             'cutoff[float]: cutoff for the nn search (meters for range and euclidean '
             'distance, degrees for beta)\n')
    FLAGS, _ = parser.parse_known_args()

    # print summary of what we will do
    tools.print_parser_args(FLAGS)
    # validate that all the post-processing hyperparameters exist in case we
    # want to do post-processing
    tools.validate_postp_args(FLAGS.postp, print_args=True)
    # does model folder exist?
    tools.validate_existing_dir(FLAGS.pretrained, "Using model from ")

    # open data config file
    datacfg = os.path.join(FLAGS.pretrained, "datacfg.yaml")
    DATA = tools.read_from_yaml(datacfg)

    # open arch config file
    traincfg = os.path.join(FLAGS.pretrained, "traincfg.yaml")
    TRAIN = tools.read_from_yaml(traincfg)

    # create log folder and copy all files to it
    tools.create_predictions_dir(FLAGS, DATA)

    # important variables
    dataset = FLAGS.dataset
    logdir = FLAGS.logdir
    pretrained = FLAGS.pretrained
    sequences_dir = os.path.join(FLAGS.logdir, "sequences")

    # architecture import
    if TRAIN["arch"] == "squeezeseg":
        from baselines.range.squeezeseg import SqueezeSeg as Net
    elif TRAIN["arch"] == "darknetseg":
        from baselines.range.darknetseg import DarknetSeg as Net
    elif TRAIN["arch"] == "darknetseg_syncbn":
        from baselines.range.darknetseg_syncbn import DarknetSegSyncBN as Net
    else:
        print("Architecture {} not defined".format(TRAIN["arch"]))
        quit()

    # data parser
    semantic_parser = SemanticKittiLoader(
        root=dataset,
        train_sequences=DATA["split"]["train"],
        valid_sequences=DATA["split"]["valid"],
        labels=DATA["labels"],
        color_map=DATA["color_map"],
        learning_map=DATA["learning_map"],
        learning_map_inv=DATA["learning_map_inv"],
        crop=[TRAIN["params"]["crop_h"], TRAIN["params"]["crop_w"]],
        batch_size=1,
        workers=TRAIN["workers"],
        gt=False)
    # architecture definition
    model = Net(
        classes=semantic_parser.get_n_classes(),
        params=TRAIN["params"],
        softmax=FLAGS.store_probs)  # only do softmax if probs are needed

    # switch to evaluate mode
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    # get weights?
    try:
        pretrained_dict = torch.load(
            os.path.join(pretrained, "weights.pth"),
            map_location=lambda storage, loc: storage)
        model.load_state_dict(pretrained_dict, strict=True)
        print("Successfully loaded model weights")
    except Exception as e:
        print("Couldn't load network, using random weights. Error: ", e)

    # report model parameters
    weights_total = sum(p.numel() for p in model.parameters())
    weights_grad = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Total number of parameters: ", weights_total)
    print("Total number of parameters requires_grad: ", weights_grad)
    # GPU?
    gpu = False
    multi_gpu = False
    model_single = model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Inferring on device: ", device)
    if torch.cuda.is_available() and torch.cuda.device_count() > 0:
        # pcl data has different sizes and benchmark is slow
        gpu = True
        torch.backends.cudnn.benchmark = False
        model.cuda()

    if FLAGS.half:
        print("Inferring in half precision!")
        model = model.half()
    # do the train and validation sets, unless only the test set was requested
    if not FLAGS.test_only:
        # train set
        infer(
            loader=semantic_parser.get_train_set(),
            model=model,
            sequences_dir=sequences_dir,
            map_fn=semantic_parser.to_original,
            probs=FLAGS.store_probs,
            postp=FLAGS.postp,
            half=FLAGS.half)

        # validation set
        infer(
            loader=semantic_parser.get_valid_set(),
            model=model,
            sequences_dir=sequences_dir,
            map_fn=semantic_parser.to_original,
            probs=FLAGS.store_probs,
            postp=FLAGS.postp,
            half=FLAGS.half)
    # do the test set
    del semantic_parser
    semantic_parser = SemanticKittiLoader(
        root=dataset,
        train_sequences=DATA["split"]["train"],
        valid_sequences=DATA["split"]["test"],
        labels=DATA["labels"],
        color_map=DATA["color_map"],
        learning_map=DATA["learning_map"],
        learning_map_inv=DATA["learning_map_inv"],
        crop=[TRAIN["params"]["crop_h"], TRAIN["params"]["crop_w"]],
        batch_size=1,
        workers=TRAIN["workers"],
        gt=False)
    # infer the test split (loaded into the valid slot above)
    infer(
        loader=semantic_parser.get_valid_set(),
        model=model,
        sequences_dir=sequences_dir,
        map_fn=semantic_parser.to_original,
        probs=FLAGS.store_probs,
        postp=FLAGS.postp,
        half=FLAGS.half)

    del semantic_parser
    print('Finished Evaluation')