-
-
Save xapaxca/2294d905f80ed783fcacc1ebff46512e to your computer and use it in GitHub Desktop.
KBR evaluation based on Monodepth's code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import, division, print_function | |
import os | |
import cv2 | |
import numpy as np | |
import torch | |
from torch.utils.data import DataLoader | |
from layers import disp_to_depth | |
from utils import readlines | |
from options import MonodepthOptions | |
import datasets | |
from load_kbr import load_kbr | |
# Disable OpenCV's own thread pool. Per the original authors, this speeds up
# evaluation 5x on their unix systems (OpenCV 3.3.1).
cv2.setNumThreads(0)

# Directory containing the evaluation split definitions, next to this script.
splits_dir = os.path.join(os.path.split(__file__)[0], "splits")

# Stereo-supervised models were trained with a nominal baseline of 0.1 units,
# whereas the KITTI rig's baseline is 54cm. Multiplying predicted depths by
# 0.54 / 0.1 = 5.4 converts stereo predictions to real-world scale.
STEREO_SCALE_FACTOR = 5.4
def compute_errors(gt, pred):
    """Compute the standard depth-evaluation metrics for matched depth samples.

    :param gt: (np.ndarray) Ground-truth depths (must be > 0).
    :param pred: (np.ndarray) Predicted depths, same shape as `gt` (must be > 0).
    :return: Tuple ``(abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3)``, where
        ``a{k}`` is the fraction of samples with ``max(gt/pred, pred/gt) < 1.25**k``.
    """
    ratio = np.maximum(gt / pred, pred / gt)
    a1, a2, a3 = ((ratio < 1.25 ** k).mean() for k in (1, 2, 3))

    diff = gt - pred
    rmse = np.sqrt(np.mean(diff ** 2))

    log_diff = np.log(gt) - np.log(pred)
    rmse_log = np.sqrt(np.mean(log_diff ** 2))

    abs_rel = np.mean(np.abs(diff) / gt)
    sq_rel = np.mean(diff ** 2 / gt)

    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
def batch_post_process_disparity(l_disp, r_disp):
    """Blend a disparity batch with its flipped-input counterpart (Monodepth v1 post-processing).

    The leftmost ~5% of columns take the value from `r_disp` and the rightmost
    ~5% from `l_disp`. A short linear ramp follows each edge, and the remaining
    columns use the mean of both.

    :param l_disp: (np.ndarray) Disparities of shape (batch, height, width).
    :param r_disp: (np.ndarray) Disparities of the flipped images, already flipped back,
        same shape as `l_disp`.
    :return: (np.ndarray) Blended disparities, shape (batch, height, width).
    """
    _, height, width = l_disp.shape
    mean_disp = (l_disp + r_disp) * 0.5

    # Horizontal coordinate in [0, 1], repeated for every row (height x width).
    ramp = np.tile(np.linspace(0, 1, width), (height, 1))

    left_weight = (1.0 - np.clip(20 * (ramp - 0.05), 0, 1))[np.newaxis]
    right_weight = np.flip(left_weight, axis=2)
    centre_weight = 1.0 - left_weight - right_weight

    return right_weight * l_disp + left_weight * r_disp + centre_weight * mean_disp
def _predict_disparities(opt):
    """Run the KBR network over the test images of `opt.eval_split` and return disparities.

    Expands and validates ``opt.load_weights_folder`` (mutating `opt`), loads
    ``kbr.ckpt`` from it and runs encoder + decoder on CUDA. Input resolution
    (192x640) and batch size (16) are hard-coded.

    :return: (np.ndarray) Scale-0 decoder outputs, one per test image, shape (N, H, W).
    """
    opt.load_weights_folder = os.path.expanduser(opt.load_weights_folder)

    assert os.path.isdir(opt.load_weights_folder), \
        "Cannot find a folder at {}".format(opt.load_weights_folder)

    print("-> Loading weights from {}".format(opt.load_weights_folder))

    filenames = readlines(os.path.join(splits_dir, opt.eval_split, "test_files.txt"))
    dataset = datasets.KITTIRAWDataset(opt.data_path, filenames,
                                       192, 640,
                                       [0], 4, is_train=False)
    dataloader = DataLoader(dataset, 16, shuffle=False, num_workers=opt.num_workers,
                            pin_memory=True, drop_last=False)

    kbr_ckpt_path = os.path.join(opt.load_weights_folder, "kbr.ckpt")
    encoder, depth_decoder = load_kbr(kbr_ckpt_path)
    encoder.cuda()
    encoder.eval()
    depth_decoder.cuda()
    depth_decoder.eval()

    pred_disps = []

    print("-> Computing predictions with size {}x{}".format(640, 192))

    with torch.no_grad():
        for data in dataloader:
            input_color = data[("color", 0, 0)].cuda()

            if opt.post_process:
                # Post-processed results require each image to have two forward passes
                input_color = torch.cat((input_color, torch.flip(input_color, [3])), 0)

            output = depth_decoder(encoder(input_color))

            # NOTE: disp_to_depth(pred_disp, opt.min_depth, opt.max_depth) is not
            # applied, so the decoder's activated scale-0 output (sigmoid by default)
            # is evaluated directly as disparity.
            pred_disp = output[0].cpu()[:, 0].numpy()

            if opt.post_process:
                N = pred_disp.shape[0] // 2
                pred_disp = batch_post_process_disparity(pred_disp[:N], pred_disp[N:, :, ::-1])

            pred_disps.append(pred_disp)

    return np.concatenate(pred_disps)


def _save_benchmark_predictions(pred_disps, save_dir):
    """Write each disparity map as a KITTI-benchmark 16-bit depth PNG into `save_dir`.

    Disparities are resized to 1216x352 and converted with STEREO_SCALE_FACTOR / disp.
    Depths are clipped to [0, 80] and stored as uint16(depth * 256) in files named
    ``{index:010d}.png``.
    """
    os.makedirs(save_dir, exist_ok=True)

    for idx, disp in enumerate(pred_disps):
        disp_resized = cv2.resize(disp, (1216, 352))
        depth = STEREO_SCALE_FACTOR / disp_resized
        depth = np.clip(depth, 0, 80)
        depth = np.uint16(depth * 256)
        save_path = os.path.join(save_dir, "{:010d}.png".format(idx))
        cv2.imwrite(save_path, depth)


def evaluate(opt):
    """Evaluate a pretrained model (or precomputed disparities) on a test split.

    Predictions are computed with the KBR network, or loaded from
    ``opt.ext_disp_to_eval`` when it is set. They are optionally saved, then
    compared against ``splits/<eval_split>/gt_depths.npz``, and mean metrics are
    printed.

    Returns early (``None``) when ``opt.no_eval`` is set, or after writing PNG
    predictions for the ``'benchmark'`` split, which has no ground truth.

    :param opt: Parsed MonodepthOptions namespace. ``opt.disable_median_scaling``
        and ``opt.pred_depth_scale_factor`` are overwritten for stereo evaluation.
    :raises ValueError: If the number of predictions differs from the number of
        ground-truth depth maps.
    """
    MIN_DEPTH = 1e-3
    MAX_DEPTH = 80

    assert sum((opt.eval_mono, opt.eval_stereo)) == 1, \
        "Please choose mono or stereo evaluation by setting either --eval_mono or --eval_stereo"

    if opt.ext_disp_to_eval is None:
        pred_disps = _predict_disparities(opt)
    else:
        # Load predictions from file
        print("-> Loading predictions from {}".format(opt.ext_disp_to_eval))
        pred_disps = np.load(opt.ext_disp_to_eval)

        if opt.eval_eigen_to_benchmark:
            eigen_to_benchmark_ids = np.load(
                os.path.join(splits_dir, "benchmark", "eigen_to_benchmark_ids.npy"))

            pred_disps = pred_disps[eigen_to_benchmark_ids]

    if opt.save_pred_disps:
        output_path = os.path.join(
            opt.load_weights_folder, "disps_{}_split.npy".format(opt.eval_split))
        print("-> Saving predicted disparities to ", output_path)
        np.save(output_path, pred_disps)

    # `return` instead of quit(): quit() is a site-module helper meant for the
    # interactive shell and is unavailable when Python runs with -S.
    if opt.no_eval:
        print("-> Evaluation disabled. Done.")
        return

    if opt.eval_split == 'benchmark':
        save_dir = os.path.join(opt.load_weights_folder, "benchmark_predictions")
        print("-> Saving out benchmark predictions to {}".format(save_dir))
        _save_benchmark_predictions(pred_disps, save_dir)
        print("-> No ground truth is available for the KITTI benchmark, so not evaluating. Done.")
        return

    gt_path = os.path.join(splits_dir, opt.eval_split, "gt_depths.npz")
    gt_depths = np.load(gt_path, fix_imports=True, encoding='latin1', allow_pickle=True)["data"]

    # Predictions and ground truth are paired by index. A count mismatch would
    # either crash mid-loop or silently score only a subset of the split.
    if len(gt_depths) != len(pred_disps):
        raise ValueError(
            "Number of predictions ({}) does not match number of ground-truth depth maps ({}) "
            "for split '{}'".format(len(pred_disps), len(gt_depths), opt.eval_split))

    print("-> Evaluating")

    if opt.eval_stereo:
        print("   Stereo evaluation - "
              "disabling median scaling, scaling by {}".format(STEREO_SCALE_FACTOR))
        opt.disable_median_scaling = True
        opt.pred_depth_scale_factor = STEREO_SCALE_FACTOR
    else:
        print("   Mono evaluation - using median scaling")

    errors = []
    ratios = []

    for pred_disp, gt_depth in zip(pred_disps, gt_depths):
        gt_height, gt_width = gt_depth.shape[:2]

        pred_disp = cv2.resize(pred_disp, (gt_width, gt_height))
        pred_depth = 1 / pred_disp

        if opt.eval_split == "eigen":
            mask = np.logical_and(gt_depth > MIN_DEPTH, gt_depth < MAX_DEPTH)

            # Garg/Eigen crop, expressed as fractions of the ground-truth image size.
            crop = np.array([0.40810811 * gt_height, 0.99189189 * gt_height,
                             0.03594771 * gt_width,  0.96405229 * gt_width]).astype(np.int32)
            crop_mask = np.zeros(mask.shape)
            crop_mask[crop[0]:crop[1], crop[2]:crop[3]] = 1
            mask = np.logical_and(mask, crop_mask)
        else:
            mask = gt_depth > 0

        pred_depth = pred_depth[mask]
        gt_depth = gt_depth[mask]

        pred_depth *= opt.pred_depth_scale_factor
        if not opt.disable_median_scaling:
            # Per-image scale alignment for scale-ambiguous (mono) predictions.
            ratio = np.median(gt_depth) / np.median(pred_depth)
            ratios.append(ratio)
            pred_depth *= ratio

        pred_depth[pred_depth < MIN_DEPTH] = MIN_DEPTH
        pred_depth[pred_depth > MAX_DEPTH] = MAX_DEPTH

        errors.append(compute_errors(gt_depth, pred_depth))

    if not opt.disable_median_scaling:
        ratios = np.array(ratios)
        med = np.median(ratios)
        print(" Scaling ratios | med: {:0.3f} | std: {:0.3f}".format(med, np.std(ratios / med)))

    mean_errors = np.array(errors).mean(0)

    print("\n  " + ("{:>8} | " * 7).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3"))
    print(("&{: 8.3f}  " * 7).format(*mean_errors.tolist()) + "\\\\")
    print("\n-> Done!")
if __name__ == "__main__":
    # Parse command-line options with the Monodepth option set and run the evaluation.
    evaluate(MonodepthOptions().parse())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import OrderedDict | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch import Tensor | |
import timm | |
# Output activations selectable by name (e.g. `MonodepthDecoder(out_act=...)`).
# Both 'none' and None mean "no activation". Each entry is one module instance,
# shared by every model that looks it up.
ACT = dict(
    sigmoid=nn.Sigmoid(),
    relu=nn.ReLU(inplace=True),
    none=nn.Identity(),
)
ACT[None] = nn.Identity()
def conv3x3(in_ch: int, out_ch: int, bias: bool = True) -> nn.Conv2d:
    """Build a 3x3 convolution with 1-pixel `reflect` padding.

    With the default stride of 1, the spatial size of the input is preserved.

    :param in_ch: (int) Number of input channels.
    :param out_ch: (int) Number of output channels.
    :param bias: (bool) Whether the convolution has a learnable bias.
    """
    return nn.Conv2d(
        in_channels=in_ch,
        out_channels=out_ch,
        kernel_size=3,
        padding=(1, 1),
        bias=bias,
        padding_mode='reflect',
    )
def conv_block(in_ch: int, out_ch: int) -> nn.Module:
    """Build a reflect-padded 3x3 convolution followed by an in-place ELU.

    The submodules are named ``conv`` and ``act``; these names appear in
    state-dict keys, so they must stay stable for checkpoint loading.
    """
    layers = OrderedDict()
    layers['conv'] = conv3x3(in_ch, out_ch)
    layers['act'] = nn.ELU(inplace=True)
    return nn.Sequential(layers)
class MonodepthDecoder(nn.Module):
    """From Monodepth(2) (https://arxiv.org/abs/1806.01260)

    Generic convolutional decoder incorporating multi-scale predictions and skip connections.

    :param num_ch_enc: (list[int]) List of channels per encoder stage.
    :param enc_sc: (list[int]) List of downsampling factor per encoder stage.
    :param upsample_mode: (str) Torch upsampling mode. {'nearest', 'bilinear'...}
    :param use_skip: (bool) If `True`, add skip connections from corresponding encoder stage.
    :param out_sc: (list[int]) List of multi-scale output downsampling factor as 2**s.
        Each entry indexes the five decoder stages, so it must lie in ``range(5)``.
    :param out_ch: (int) Number of output channels.
    :param out_act: (str) Activation to apply to each output stage (a key of ``ACT``).
    :raises KeyError: If `out_act` is not a key of ``ACT``.
    """

    def __init__(self,
                 num_ch_enc: list[int],
                 enc_sc: list[int],
                 upsample_mode: str = 'nearest',
                 use_skip: bool = True,
                 out_sc: list[int] = (0, 1, 2, 3),
                 out_ch: int = 1,
                 out_act: str = 'sigmoid'):
        super().__init__()
        self.num_ch_enc = num_ch_enc
        self.enc_sc = enc_sc
        self.upsample_mode = upsample_mode
        self.use_skip = use_skip
        self.out_sc = out_sc
        self.out_ch = out_ch
        self.out_act = out_act

        if self.out_act not in ACT:
            raise KeyError(f'Invalid activation key. ({self.out_act} vs. {tuple(ACT.keys())})')
        self.act = ACT[self.out_act]

        self.num_ch_dec = [16, 32, 64, 128, 256]

        # NOTE: the insertion order of `convs` determines the `decoder.<idx>` names
        # in the state dict (via the ModuleList below), so checkpoints depend on it.
        self.convs = OrderedDict()
        for i in range(4, -1, -1):
            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[f'upconv_{i}_0'] = conv_block(num_ch_in, num_ch_out)

            num_ch_in = self.num_ch_dec[i]
            skip_idx = self._skip_index(i)
            if skip_idx is not None:
                num_ch_in += self.num_ch_enc[skip_idx]
            self.convs[f'upconv_{i}_1'] = conv_block(num_ch_in, num_ch_out)

        # Create multi-scale outputs
        for i in self.out_sc:
            self.convs[f'outconv_{i}'] = conv3x3(self.num_ch_dec[i], self.out_ch)

        self.decoder = nn.ModuleList(list(self.convs.values()))

    def _skip_index(self, i: int):
        """Return the index of the encoder stage skip-connected into decoder stage `i`, or None.

        Shared by ``__init__`` (channel counts) and ``forward`` (concatenation) so
        both always agree on which skips exist.
        """
        sf = 2 ** i  # NOTE: Skip connection resolution, which is current scale upsampled x2
        if self.use_skip and sf in self.enc_sc:
            return self.enc_sc.index(sf)
        return None

    def forward(self, feat: list[Tensor]) -> dict[int, Tensor]:
        """Decode encoder features into multi-scale predictions.

        :param feat: (list[Tensor]) Encoder feature maps, one per stage, ordered like
            `num_ch_enc` / `enc_sc`. The last entry is the decoder input.
        :return: (dict[int, Tensor]) Activated prediction for each scale in `out_sc`,
            keyed by that scale index.
        """
        out = {}
        x = feat[-1]
        for i in range(4, -1, -1):
            x = self.convs[f'upconv_{i}_0'](x)
            x = [F.interpolate(x, scale_factor=2, mode=self.upsample_mode)]

            skip_idx = self._skip_index(i)
            if skip_idx is not None:
                x.append(feat[skip_idx])
            x = torch.cat(x, 1)
            x = self.convs[f'upconv_{i}_1'](x)

            if i in self.out_sc:
                out[i] = self.act(self.convs[f'outconv_{i}'](x))

        return out
def load_weights(model, state_dict, prefix):
    """Load the entries of `state_dict` whose keys start with `prefix` into `model`.

    The leading `prefix` is stripped from each matching key, and non-matching keys
    are ignored. Loading is non-strict: missing/unexpected keys are reported as
    counts on stdout rather than raised.

    :param model: (nn.Module) Module to load the weights into.
    :param state_dict: (dict[str, Tensor]) Full checkpoint state dict.
    :param prefix: (str) Key prefix selecting this sub-model, e.g. ``'nets.depth.encoder.'``.
    :return: The ``(missing_keys, unexpected_keys)`` named tuple returned by
        ``model.load_state_dict``.
    """
    # Slice rather than str.replace: replace() would also delete any later
    # occurrence of the prefix substring inside the key, corrupting the name.
    filtered_weights = {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }

    result = model.load_state_dict(filtered_weights, strict=False)
    missing_keys, unexpected_keys = result

    print(f"Loading weights with prefix '{prefix}':")
    print(f"\tTotal number of keys: {len(filtered_weights.keys())}")
    print(f"\tNumber of missing keys: {len(missing_keys)}")
    print(f"\tNumber of unexpected keys: {len(unexpected_keys)}")
    return result
def load_kbr(ckpt_path, pretrained=True):
    """Build the KBR depth network (timm ConvNeXt-Base encoder + MonodepthDecoder) and load a checkpoint.

    :param ckpt_path: Path to a checkpoint whose ``'state_dict'`` holds encoder keys
        prefixed ``nets.depth.encoder.`` and decoder keys prefixed ``nets.depth.decoders.disp.``.
    :param pretrained: Forwarded to ``timm.create_model``. The default ``True`` keeps the
        original behaviour of first initialising the encoder from timm's pretrained
        weights. Loading is non-strict, so any encoder weight absent from the
        checkpoint keeps that initial value. Pass ``False`` to skip the download when
        the checkpoint is known to be complete.
    :return: ``(depth_encoder, depth_decoder)`` on CPU.
    """
    depth_encoder = timm.create_model("convnext_base", features_only=True, pretrained=pretrained)
    depth_decoder = MonodepthDecoder(
        num_ch_enc=depth_encoder.feature_info.channels(),
        enc_sc=depth_encoder.feature_info.reduction(),
    )

    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    state_dict = checkpoint['state_dict']
    load_weights(depth_encoder, state_dict, 'nets.depth.encoder.')
    load_weights(depth_decoder, state_dict, 'nets.depth.decoders.disp.')
    return depth_encoder, depth_decoder
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment