KBR evaluation based on Monodepth2's evaluation code

from __future__ import absolute_import, division, print_function

import os
import cv2
import numpy as np

import torch
from torch.utils.data import DataLoader

from layers import disp_to_depth
from utils import readlines
from options import MonodepthOptions
import datasets

from load_kbr import load_kbr

cv2.setNumThreads(0)  # This speeds up evaluation 5x on our unix systems (OpenCV 3.3.1)

splits_dir = os.path.join(os.path.dirname(__file__), "splits")

# Models which were trained with stereo supervision were trained with a nominal
# baseline of 0.1 units. The KITTI rig has a baseline of 54cm. Therefore,
# to convert our stereo predictions to real-world scale we multiply our depths by 5.4.
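# (i.e. 0.54 m real baseline / 0.1 nominal baseline = 5.4)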
STEREO_SCALE_FACTOR = 5.4


def compute_errors(gt, pred):
    """Computation of error metrics between predicted and ground truth depths
    """
    thresh = np.maximum((gt / pred), (pred / gt))
    a1 = (thresh < 1.25     ).mean()
    a2 = (thresh < 1.25 ** 2).mean()
    a3 = (thresh < 1.25 ** 3).mean()

    rmse = (gt - pred) ** 2
    rmse = np.sqrt(rmse.mean())

    rmse_log = (np.log(gt) - np.log(pred)) ** 2
    rmse_log = np.sqrt(rmse_log.mean())

    abs_rel = np.mean(np.abs(gt - pred) / gt)

    sq_rel = np.mean(((gt - pred) ** 2) / gt)

    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
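
# Illustrative sanity check for compute_errors (not part of the original gist):
#
#     gt = np.array([2.0, 4.0, 10.0])
#     pred = np.array([2.2, 3.8, 12.0])
#     abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 = compute_errors(gt, pred)
#     # per-pixel thresh = max(gt/pred, pred/gt) = [1.1, 1.053, 1.2],
#     # all below 1.25, so a1 = a2 = a3 = 1.0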


def batch_post_process_disparity(l_disp, r_disp):
    """Apply the disparity post-processing method as introduced in Monodepthv1
    """
    _, h, w = l_disp.shape
    m_disp = 0.5 * (l_disp + r_disp)
    l, _ = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h))
    l_mask = (1.0 - np.clip(20 * (l - 0.05), 0, 1))[None, ...]
    r_mask = l_mask[:, :, ::-1]
    return r_mask * l_disp + l_mask * r_disp + (1.0 - l_mask - r_mask) * m_disp
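
# How the blend works (explanatory comment, not in the original gist): l_mask is 1
# over the leftmost ~5% of the width and ramps to 0 by ~10%; r_mask is its mirror.
# Near the left border the flipped-pass prediction (r_disp) dominates, near the
# right border the original-pass prediction (l_disp), with the mean m_disp used
# elsewhere, suppressing border occlusion artefacts from each single pass:
#
#     w = 100
#     l = np.linspace(0, 1, w)
#     l_mask = 1.0 - np.clip(20 * (l - 0.05), 0, 1)
#     # l_mask[:5] == 1, l_mask[10:] == 0, linear ramp in between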


def evaluate(opt):
    """Evaluates a pretrained model using a specified test set
    """
    MIN_DEPTH = 1e-3
    MAX_DEPTH = 80

    assert sum((opt.eval_mono, opt.eval_stereo)) == 1, \
        "Please choose mono or stereo evaluation by setting either --eval_mono or --eval_stereo"

    if opt.ext_disp_to_eval is None:

        opt.load_weights_folder = os.path.expanduser(opt.load_weights_folder)

        assert os.path.isdir(opt.load_weights_folder), \
            "Cannot find a folder at {}".format(opt.load_weights_folder)

        print("-> Loading weights from {}".format(opt.load_weights_folder))

        filenames = readlines(os.path.join(splits_dir, opt.eval_split, "test_files.txt"))

        dataset = datasets.KITTIRAWDataset(opt.data_path, filenames,
                                           192, 640,
                                           [0], 4, is_train=False)
        dataloader = DataLoader(dataset, 16, shuffle=False, num_workers=opt.num_workers,
                                pin_memory=True, drop_last=False)

        kbr_ckpt_path = os.path.join(opt.load_weights_folder, "kbr.ckpt")
        encoder, depth_decoder = load_kbr(kbr_ckpt_path)

        encoder.cuda()
        encoder.eval()
        depth_decoder.cuda()
        depth_decoder.eval()

        pred_disps = []

        print("-> Computing predictions with size {}x{}".format(640, 192))

        with torch.no_grad():
            for data in dataloader:
                input_color = data[("color", 0, 0)].cuda()

                if opt.post_process:
                    # Post-processed results require each image to have two forward passes
                    input_color = torch.cat((input_color, torch.flip(input_color, [3])), 0)

                output = depth_decoder(encoder(input_color))

                # The decoder returns a dict of multi-scale sigmoid disparities;
                # scale 0 is the full-resolution output
                pred_disp = output[0]
                # pred_disp, _ = disp_to_depth(pred_disp, opt.min_depth, opt.max_depth)
                pred_disp = pred_disp.cpu()[:, 0].numpy()

                if opt.post_process:
                    N = pred_disp.shape[0] // 2
                    pred_disp = batch_post_process_disparity(pred_disp[:N], pred_disp[N:, :, ::-1])

                pred_disps.append(pred_disp)

        pred_disps = np.concatenate(pred_disps)
    else:
        # Load predictions from file
        print("-> Loading predictions from {}".format(opt.ext_disp_to_eval))
        pred_disps = np.load(opt.ext_disp_to_eval)

        if opt.eval_eigen_to_benchmark:
            eigen_to_benchmark_ids = np.load(
                os.path.join(splits_dir, "benchmark", "eigen_to_benchmark_ids.npy"))

            pred_disps = pred_disps[eigen_to_benchmark_ids]

    if opt.save_pred_disps:
        output_path = os.path.join(
            opt.load_weights_folder, "disps_{}_split.npy".format(opt.eval_split))
        print("-> Saving predicted disparities to ", output_path)
        np.save(output_path, pred_disps)

    if opt.no_eval:
        print("-> Evaluation disabled. Done.")
        quit()

    elif opt.eval_split == 'benchmark':
        save_dir = os.path.join(opt.load_weights_folder, "benchmark_predictions")
        print("-> Saving out benchmark predictions to {}".format(save_dir))
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for idx in range(len(pred_disps)):
            disp_resized = cv2.resize(pred_disps[idx], (1216, 352))
            depth = STEREO_SCALE_FACTOR / disp_resized
            depth = np.clip(depth, 0, 80)
            depth = np.uint16(depth * 256)
            save_path = os.path.join(save_dir, "{:010d}.png".format(idx))
            cv2.imwrite(save_path, depth)

        print("-> No ground truth is available for the KITTI benchmark, so not evaluating. Done.")
        quit()
    gt_path = os.path.join(splits_dir, opt.eval_split, "gt_depths.npz")
    gt_depths = np.load(gt_path, fix_imports=True, encoding='latin1', allow_pickle=True)["data"]

    print("-> Evaluating")

    if opt.eval_stereo:
        print("   Stereo evaluation - "
              "disabling median scaling, scaling by {}".format(STEREO_SCALE_FACTOR))
        opt.disable_median_scaling = True
        opt.pred_depth_scale_factor = STEREO_SCALE_FACTOR
    else:
        print("   Mono evaluation - using median scaling")

    errors = []
    ratios = []

    for i in range(pred_disps.shape[0]):

        gt_depth = gt_depths[i]
        gt_height, gt_width = gt_depth.shape[:2]

        pred_disp = pred_disps[i]
        pred_disp = cv2.resize(pred_disp, (gt_width, gt_height))
        pred_depth = 1 / pred_disp

        if opt.eval_split == "eigen":
            mask = np.logical_and(gt_depth > MIN_DEPTH, gt_depth < MAX_DEPTH)

            # Standard Garg/Eigen evaluation crop used by prior work
            crop = np.array([0.40810811 * gt_height, 0.99189189 * gt_height,
                             0.03594771 * gt_width,  0.96405229 * gt_width]).astype(np.int32)
            crop_mask = np.zeros(mask.shape)
            crop_mask[crop[0]:crop[1], crop[2]:crop[3]] = 1
            mask = np.logical_and(mask, crop_mask)

        else:
            mask = gt_depth > 0

        pred_depth = pred_depth[mask]
        gt_depth = gt_depth[mask]

        pred_depth *= opt.pred_depth_scale_factor
        if not opt.disable_median_scaling:
            ratio = np.median(gt_depth) / np.median(pred_depth)
            ratios.append(ratio)
            pred_depth *= ratio

        pred_depth[pred_depth < MIN_DEPTH] = MIN_DEPTH
        pred_depth[pred_depth > MAX_DEPTH] = MAX_DEPTH

        errors.append(compute_errors(gt_depth, pred_depth))

    if not opt.disable_median_scaling:
        ratios = np.array(ratios)
        med = np.median(ratios)
        print(" Scaling ratios | med: {:0.3f} | std: {:0.3f}".format(med, np.std(ratios / med)))

    mean_errors = np.array(errors).mean(0)

    print("\n  " + ("{:>8} | " * 7).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3"))
    print(("&{: 8.3f} " * 7).format(*mean_errors.tolist()) + "\\\\")
    print("\n-> Done!")


if __name__ == "__main__":
    options = MonodepthOptions()
    evaluate(options.parse())
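
# Example invocation (hedged: this script reuses Monodepth2's MonodepthOptions
# parser, so the flags follow that convention; the script filename and paths
# below are placeholders):
#
#     python evaluate_kbr.py --eval_mono --eval_split eigen \
#         --data_path /path/to/kitti_raw \
#         --load_weights_folder /path/to/folder_containing_kbr.ckpt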


# ---------------------------------------------------------------------------
# load_kbr.py (second file in this gist, imported above via `from load_kbr import load_kbr`)
# ---------------------------------------------------------------------------
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

import timm

ACT = {
    'sigmoid': nn.Sigmoid(),
    'relu': nn.ReLU(inplace=True),
    'none': nn.Identity(),
    None: nn.Identity(),
}


def conv3x3(in_ch: int, out_ch: int, bias: bool = True) -> nn.Conv2d:
    """Conv layer with 3x3 kernel and `reflect` padding."""
    return nn.Conv2d(in_ch, out_ch, kernel_size=(3, 3), padding=1, padding_mode='reflect', bias=bias)


def conv_block(in_ch: int, out_ch: int) -> nn.Module:
    """Layer to perform a convolution followed by ELU."""
    return nn.Sequential(OrderedDict({
        'conv': conv3x3(in_ch, out_ch),
        'act': nn.ELU(inplace=True),
    }))


class MonodepthDecoder(nn.Module):
    """From Monodepth(2) (https://arxiv.org/abs/1806.01260)
    Generic convolutional decoder incorporating multi-scale predictions and skip connections.

    :param num_ch_enc: (list[int]) List of channels per encoder stage.
    :param enc_sc: (list[int]) List of downsampling factors per encoder stage.
    :param upsample_mode: (str) Torch upsampling mode. {'nearest', 'bilinear'...}
    :param use_skip: (bool) If `True`, add skip connections from the corresponding encoder stage.
    :param out_sc: (list[int]) List of multi-scale output downsampling factors as 2**s.
    :param out_ch: (int) Number of output channels.
    :param out_act: (str) Activation to apply to each output stage.
    """
    def __init__(self,
                 num_ch_enc: list[int],
                 enc_sc: list[int],
                 upsample_mode: str = 'nearest',
                 use_skip: bool = True,
                 out_sc: list[int] = (0, 1, 2, 3),
                 out_ch: int = 1,
                 out_act: str = 'sigmoid'):
        super().__init__()
        self.num_ch_enc = num_ch_enc
        self.enc_sc = enc_sc
        self.upsample_mode = upsample_mode
        self.use_skip = use_skip
        self.out_sc = out_sc
        self.out_ch = out_ch
        self.out_act = out_act

        if self.out_act not in ACT:
            raise KeyError(f'Invalid activation key. ({self.out_act} vs. {tuple(ACT.keys())})')
        self.act = ACT[self.out_act]

        self.num_ch_dec = [16, 32, 64, 128, 256]

        self.convs = OrderedDict()
        for i in range(4, -1, -1):
            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i+1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[f'upconv_{i}_{0}'] = conv_block(num_ch_in, num_ch_out)

            num_ch_in = self.num_ch_dec[i]
            sf = 2**i  # NOTE: Skip connection resolution, which is current scale upsampled x2
            if self.use_skip and sf in self.enc_sc:
                idx = self.enc_sc.index(sf)
                num_ch_in += self.num_ch_enc[idx]
            num_ch_out = self.num_ch_dec[i]
            self.convs[f'upconv_{i}_{1}'] = conv_block(num_ch_in, num_ch_out)

        # Create multi-scale outputs
        for i in self.out_sc:
            self.convs[f'outconv_{i}'] = conv3x3(self.num_ch_dec[i], self.out_ch)

        self.decoder = nn.ModuleList(list(self.convs.values()))

    def forward(self, feat: list[Tensor]) -> dict[int, Tensor]:
        out = {}
        x = feat[-1]
        for i in range(4, -1, -1):
            x = self.convs[f'upconv_{i}_{0}'](x)
            x = [F.interpolate(x, scale_factor=2, mode=self.upsample_mode)]

            sf = 2**i
            if self.use_skip and sf in self.enc_sc:
                idx = self.enc_sc.index(sf)
                x += [feat[idx]]

            x = torch.cat(x, 1)
            x = self.convs[f'upconv_{i}_{1}'](x)

            if i in self.out_sc:
                out[i] = self.act(self.convs[f'outconv_{i}'](x))

        return out
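
# Illustrative pairing with a timm ConvNeXt encoder (mirrors load_kbr() below;
# not part of the original gist). A features_only ConvNeXt-Base yields 4 feature
# maps at strides (4, 8, 16, 32); five x2 upsamplings bring stride 32 back to
# full resolution, so out[0] is the full-size sigmoid disparity:
#
#     enc = timm.create_model("convnext_base", features_only=True, pretrained=False)
#     dec = MonodepthDecoder(num_ch_enc=enc.feature_info.channels(),  # [128, 256, 512, 1024]
#                            enc_sc=enc.feature_info.reduction())     # [4, 8, 16, 32]
#     out = dec(enc(torch.randn(1, 3, 192, 640)))
#     assert out[0].shape == (1, 1, 192, 640)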


def load_weights(model, state_dict, prefix):
    """Load the subset of `state_dict` whose keys start with `prefix` into `model`."""
    filtered_weights = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            new_key = key[len(prefix):]  # strip only the leading prefix
            filtered_weights[new_key] = value
    missing_keys, unexpected_keys = model.load_state_dict(filtered_weights, strict=False)
    print(f"Loading weights with prefix '{prefix}':")
    print(f"\tTotal number of keys: {len(filtered_weights.keys())}")
    print(f"\tNumber of missing keys: {len(missing_keys)}")
    print(f"\tNumber of unexpected keys: {len(unexpected_keys)}")


def load_kbr(ckpt_path):
    """Build the KBR ConvNeXt-Base encoder + Monodepth decoder and load checkpoint weights."""
    depth_encoder = timm.create_model("convnext_base", features_only=True, pretrained=True)
    depth_decoder = MonodepthDecoder(num_ch_enc=depth_encoder.feature_info.channels(),
                                     enc_sc=depth_encoder.feature_info.reduction())
    # The checkpoint stores a 'state_dict' whose keys are prefixed with 'nets.depth.*'
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    load_weights(depth_encoder, checkpoint['state_dict'], 'nets.depth.encoder.')
    load_weights(depth_decoder, checkpoint['state_dict'], 'nets.depth.decoders.disp.')
    return depth_encoder, depth_decoder