Created March 16, 2019 15:58
KITTI stereo disparity 3 pixel error, corrected
* [PMSNet code](
* [type converter](
MATLAB code:
% disp_error.m
% Fraction of valid ground-truth pixels whose estimated disparity is wrong.
%   D_gt  - ground-truth disparity map; only entries > 0 are evaluated
%   D_est - estimated disparity map (assumed to have the same size as D_gt - TODO confirm)
%   tau   - two thresholds: tau(1) on the absolute error, tau(2) on the
%           error relative to abs(D_gt)
% A pixel counts as an error only when BOTH thresholds are strictly exceeded.
function d_err = disp_error (D_gt,D_est,tau)
% per-pixel absolute disparity error
E = abs(D_gt-D_est);
% number of valid pixels that exceed both the absolute and the relative threshold
n_err = length(find(D_gt>0 & E>tau(1) & E./abs(D_gt)>tau(2)));
% number of valid (D_gt > 0) pixels
n_total = length(find(D_gt>0));
% ratio of erroneous pixels to valid pixels
d_err = n_err/n_total;
% error threshold: [absolute threshold, relative threshold], consumed by
% disp_error as tau(1) and tau(2)
tau = [3 0.05];
% stereo demo
disp('Load and show disparity map ... ');
% NOTE(review): disp_read is not shown here; presumably it decodes a
% disparity PNG into a numeric map - verify against its definition.
D_est = disp_read('data/disp_est.png');
D_gt = disp_read('data/disp_gt.png');
% error ratio of the estimate against the ground truth
d_err = disp_error(D_gt,D_est,tau);
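So a pixel is an error only when both comparisons are strictly exceeded; when testing for *accepted* pixels we should therefore include equality (`<=`) in both comparisons.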
Framework: PyTorch
import torch
import numpy as np
def np_dtype_to_torch(dtype):
    """Return the legacy torch tensor type that matches a NumPy dtype.

    Args:
        dtype: A ``np.dtype`` instance, or anything ``np.dtype()`` accepts
            (e.g. ``np.float32`` or ``"int64"``).

    Returns:
        One of the ``torch.*Tensor`` type objects, suitable for
        ``Tensor.type(...)``.

    Raises:
        KeyError: If the dtype has no entry in the mapping below.
    """
    type_map = {
        np.dtype(np.float16): torch.HalfTensor,
        np.dtype(np.float32): torch.FloatTensor,
        np.dtype(np.float64): torch.DoubleTensor,
        np.dtype(np.int32): torch.IntTensor,
        np.dtype(np.int64): torch.LongTensor,
        np.dtype(np.uint8): torch.ByteTensor,
    }
    # Normalise first: a scalar type such as np.float32 compares equal to
    # its dtype but hashes differently, so it would miss the dict lookup.
    try:
        key = np.dtype(dtype)
    except TypeError:
        # Keep the historical exception type for unsupported input.
        raise KeyError(dtype) from None
    return type_map[key]
def to_tensor(arg):
    """Convert a NumPy array, list or tuple into a torch tensor.

    Lists and tuples are first turned into a NumPy array; the resulting
    tensor is then cast to the torch type that ``np_dtype_to_torch`` returns
    for the array's dtype.

    Raises:
        ValueError: If ``arg`` is not an ndarray, list or tuple.
    """
    if isinstance(arg, np.ndarray):
        array = arg
    elif isinstance(arg, (list, tuple)):
        array = np.array(arg)
    else:
        raise ValueError("unsupported arg type.")

    tensor = torch.from_numpy(array)
    target_type = np_dtype_to_torch(array.dtype)
    return tensor.type(target_type)
def calc_3pe_np(disp_pred, disp_true):
    """Compute the 3-pixel error of a batch by delegating to NumPy code.

    Both tensors must have the same 3-dimensional shape. Each is copied,
    cast to a float tensor and converted to NumPy; every slice along the
    first axis is then scored with ``calc_3pe_standalone`` (prediction as
    ``disp_src``, ground truth as ``disp_dst``).

    Returns:
        A tuple ``(mean_error, per_slice_errors)``.
    """
    assert disp_pred.shape == disp_true.shape
    assert disp_pred.dim() == 3

    pred_batch = disp_pred.clone().type(torch.FloatTensor).cpu().numpy()
    true_batch = disp_true.clone().type(torch.FloatTensor).cpu().numpy()

    per_slice = [
        calc_3pe_standalone(pred_img, true_img)
        for pred_img, true_img in zip(pred_batch, true_batch)
    ]
    return np.mean(per_slice), per_slice
def calc_3pe_th(disp_pred, disp_true):
    """Compute the 3-pixel error of a batch with torch operations only.

    A pixel is evaluated when its ground-truth disparity is > 0. It counts
    as wrong when the absolute error is strictly greater than 3 AND strictly
    greater than 5% of the ground-truth disparity. Note that only the
    ground truth is masked here; the prediction is not checked for
    validity.

    Args:
        disp_pred: Predicted disparities, a 3-dimensional tensor whose first
            axis indexes the slices that are scored separately.
        disp_true: Ground-truth disparities, same shape as ``disp_pred``.

    Returns:
        A tuple ``(mean_error, per_slice_errors)`` where ``per_slice_errors``
        is a list of Python floats. A slice without any valid ground-truth
        pixel contributes an error of 0.0.
    """
    assert disp_pred.shape == disp_true.shape
    assert disp_pred.dim() == 3

    pred = disp_pred.detach().to(torch.float32)
    true = disp_true.detach().to(torch.float32)

    # Boolean masks avoid the np.argwhere-on-a-tensor trick, whose output
    # layout depends on NumPy's fallback behaviour.
    valid = true > 0
    diff = torch.abs(true - pred)
    wrong = valid & (diff > 3) & (diff > true * 0.05)

    wrong_counts = wrong.sum(dim=(1, 2)).tolist()
    valid_counts = valid.sum(dim=(1, 2)).tolist()

    err_l = []
    for wrong_n, valid_n in zip(wrong_counts, valid_counts):
        if valid_n > 0:
            err = wrong_n / valid_n
        else:
            # No valid ground truth in this slice: report zero error.
            err = 0.0
        err_l.append(err)
    return np.mean(err_l), err_l
def calc_3pe_standalone(disp_src, disp_dst):
    """Compute the 3-pixel error between two 2-D disparity maps (NumPy).

    Only pixels where both maps are > 0 and not NaN are evaluated. A pixel
    is accepted when its absolute error is <= 3 or <= 5% of ``disp_dst``;
    the returned value is the fraction of evaluated pixels that are not
    accepted.

    Args:
        disp_src: 2-D array of disparities (the script's caller passes the
            estimate here).
        disp_dst: 2-D array of disparities with the same shape (the caller
            passes the ground truth here); the 5% threshold is relative to
            this map.

    Returns:
        The error ratio as a float, or 0.0 when no pixel is valid.
    """
    assert disp_src.shape == disp_dst.shape, "{}, {}".format(
        disp_src.shape, disp_dst.shape)
    assert len(disp_src.shape) == 2  # (N*M)

    not_empty = ((disp_src > 0) & (~np.isnan(disp_src))
                 & (disp_dst > 0) & (~np.isnan(disp_dst)))
    disp_src_flatten = disp_src[not_empty].flatten().astype(np.float32)
    disp_dst_flatten = disp_dst[not_empty].flatten().astype(np.float32)

    disp_diff_l = np.abs(disp_src_flatten - disp_dst_flatten)
    if disp_diff_l.size == 0:
        # Nothing to evaluate; avoid dividing by zero.
        return 0.0

    accept_3p = (disp_diff_l <= 3) | (disp_diff_l <= disp_dst_flatten * 0.05)
    err_3p = 1 - np.count_nonzero(accept_3p) / len(disp_diff_l)
    return err_3p
if __name__ == '__main__':
    # Self-check and micro-benchmark: the three implementations must agree
    # on random data, then the two batched variants are timed.
    import timeit

    # bz, h, w = (1, 2, 5)
    bz, h, w = (20, 384, 1248)
    # Random disparities in [10.0, 200.0) with one decimal place.
    est_disp = np.random.randint(low=100, high=2000, size=(bz, h, w)) / 10.0
    gt_disp = np.random.randint(low=100, high=2000, size=(bz, h, w)) / 10.0

    # Reference value: mean of the per-image standalone errors.
    err_s = 0
    for est, gt in zip(est_disp, gt_disp):
        err_s += calc_3pe_standalone(disp_src=est, disp_dst=gt)  # gt as dst
    err_s = err_s / len(est_disp)
    print("STANDALONE: {}".format(err_s))

    tensor1 = to_tensor(est_disp)
    tensor2 = to_tensor(gt_disp)
    err_np, err_l = calc_3pe_np(tensor1, tensor2)
    err_th, err_l_th = calc_3pe_th(tensor1, tensor2)
    assert np.isclose(err_s, err_np)
    assert np.isclose(err_s, err_th)

    def run1():
        calc_3pe_np(tensor1, tensor2)

    def run2():
        calc_3pe_th(tensor1, tensor2)

    # Report both timings; previously each result overwrote the same
    # variable and was never shown.
    t_np = timeit.timeit("run1()", setup="from __main__ import run1",
                         number=100)
    print("calc_3pe_np: {:.4f}s for 100 runs".format(t_np))
    t_th = timeit.timeit("run2()", setup="from __main__ import run2",
                         number=100)
    print("calc_3pe_th: {:.4f}s for 100 runs".format(t_th))
