Last active
September 4, 2022 16:33
-
-
Save ndrplz/ac164425c8e4d19b8a14566f497eef9b to your computer and use it in GitHub Desktop.
Define and train the homography baseline for "Learning to Map Vehicles into Bird's Eye View"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
HomographyModel maps the bottom of the frontal frame bbox to the birdeye view. | |
Length and width of the bounding box are evaluated as mean across training. | |
""" | |
import os | |
import cv2 | |
import pickle | |
import numpy as np | |
from glob import glob | |
from os.path import join | |
from tqdm import tqdm | |
class HomographyModel:
    """
    Baseline that maps frontal-view boxes to the bird's-eye view with one homography.

    The two bottom corners of every frontal bounding box are paired with the two bottom
    corners of the matching bird's-eye box. A single homography is estimated over all
    collected pairs.

    At prediction time, the projected bottom corners give the horizontal extent and the
    bottom edge of the bird's-eye box. The box height is the mean bird's-eye height seen
    during training. The width sum (``w_sum``) is accumulated and persisted too, but
    ``predict`` does not use it.
    """

    def __init__(self):
        self.H = None  # homography matrix; None means "estimate on the next predict()"
        self.src_points = []  # frontal-view (x, y) point correspondences
        self.dst_points = []  # bird's-eye (x, y) point correspondences
        self.h_sum = 0  # running sum of bird's-eye box heights
        self.w_sum = 0  # running sum of bird's-eye box widths
        self.n_images = 0  # number of (frontal, bird's-eye) box pairs passed to fit()

    def fit(self, frontal_coords, birdeye_coords):
        """
        Accumulate point correspondences and box-size statistics from a batch.

        The homography is not computed here. It is (re-)estimated lazily by the next
        call to ``predict``.

        :param frontal_coords: frontal coordinates in a (batchsize, 4) array. (x_min, y_min, x_max, y_max)
        :param birdeye_coords: birdeye coordinates in a (batchsize, 4) array. (x_min, y_min, x_max, y_max)
        :return: None
        :raises ValueError: if the two batches have a different number of rows.
        """
        if len(frontal_coords) != len(birdeye_coords):
            raise ValueError('frontal_coords and birdeye_coords must have the same batch size '
                             '({} != {})'.format(len(frontal_coords), len(birdeye_coords)))
        for frontal_box, birdeye_box in zip(frontal_coords, birdeye_coords):
            xf_min, _, xf_max, yf_max = frontal_box
            xb_min, yb_min, xb_max, yb_max = birdeye_box
            # Only the bottom corners (at y_max) of each box are used as correspondences.
            self.src_points += [(xf_min, yf_max), (xf_max, yf_max)]
            self.dst_points += [(xb_min, yb_max), (xb_max, yb_max)]
            self.h_sum += (yb_max - yb_min)
            self.w_sum += (xb_max - xb_min)
            self.n_images += 1
        # New correspondences invalidate any previously estimated homography.
        self.H = None

    def _estimate_homography(self):
        """Estimate the homography from all correspondences accumulated by ``fit``."""
        # cv2.findHomography needs at least 4 point pairs, i.e. at least 2 fitted boxes.
        if len(self.src_points) < 4:
            raise ValueError('At least 2 training samples are required to estimate the '
                             'homography; call fit() first.')
        print('Re-estimating homography...')
        # cv2 only accepts floating-point point arrays.
        H, _ = cv2.findHomography(np.asarray(self.src_points, dtype=np.float64),
                                  np.asarray(self.dst_points, dtype=np.float64))
        if H is None:
            # cv2 may return None when the correspondences are degenerate (e.g. collinear).
            raise ValueError('Homography estimation failed: degenerate point correspondences.')
        return H

    def predict(self, coords):
        """
        Map a batch of frontal-view boxes to bird's-eye-view boxes.

        If no homography is available yet (first call, or ``fit`` was called since the
        last estimate), it is estimated from the accumulated correspondences.

        :param coords: batch of frontal coordinates as a numpy array (batchsize, 4). (x_min, y_min, x_max, y_max)
        :return: pred: batch of birdeye coordinates as a numpy array (batchsize, 4), with the
                 same dtype as ``coords``. (x_min, y_min, x_max, y_max)
        :raises ValueError: if the model has not been fitted on enough data.
        """
        coords = np.asarray(coords)
        if self.H is None:
            self.H = self._estimate_homography()
        if self.n_images == 0:
            raise ValueError('The model has no training statistics; call fit() first.')
        # Mean bird's-eye box height over the training set (independent of the batch).
        h = self.h_sum / self.n_images

        x_min, x_max, y_max = coords[:, 0], coords[:, 2], coords[:, 3]
        ones = np.ones(len(coords))
        # Project the bottom-left and bottom-right corners in homogeneous coordinates,
        # then normalise by the third (w) component.
        left = np.dot(self.H, np.stack([x_min, y_max, ones]))
        right = np.dot(self.H, np.stack([x_max, y_max, ones]))
        left = left[:2] / left[2]
        right = right[:2] / right[2]

        # The bottom edge is placed at the mean projected y of the two corners.
        bottom = (left[1] + right[1]) / 2
        pred = np.stack([left[0], bottom - h, right[0], bottom], axis=1)
        return pred.astype(coords.dtype)

    def save_weights(self, filename):
        """
        Save the model state to a pickle file, creating missing parent directories.

        :param filename: the name of the file to store the model in.
        :return: None
        """
        directory = os.path.dirname(filename)
        # dirname() is '' for a bare filename, and os.makedirs('') would raise.
        if directory and not os.path.exists(directory):
            os.makedirs(directory)
        dump = {
            'H': self.H,
            'src_points': self.src_points,
            'dst_points': self.dst_points,
            'h_sum': self.h_sum,
            'w_sum': self.w_sum,
            'n_images': self.n_images,
        }
        with open(filename, 'wb') as handle:
            pickle.dump(dump, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def load_weights(self, filename):
        """
        Load the model state from a pickle file written by ``save_weights``.

        Only load trusted files: unpickling can execute arbitrary code.

        :param filename: the name of the file to load the model from.
        :return: None
        """
        with open(filename, 'rb') as handle:
            dump = pickle.load(handle)  # NOTE: unsafe on untrusted input
        self.H = dump['H']
        self.src_points = dump['src_points']
        self.dst_points = dump['dst_points']
        self.h_sum = dump['h_sum']
        self.w_sum = dump['w_sum']
        self.n_images = dump['n_images']
def read_lines_from_file(filename):
    """
    Return every line of a text file, each still ending with its newline character.

    :param filename: path of the file to read.
    :return: list of strings, one per line.
    """
    with open(filename, mode='r') as handle:
        lines = list(handle)
    return lines
def get_training_data(dataset_root=None):
    """
    Load the frontal (X) and bird's-eye (Y) box coordinates used for training.

    Every ``filtered_data.txt`` found one directory level below ``dataset_root`` is read
    as comma-separated lines. Columns 6-9 (0-based) hold the frontal box, and columns 10
    onwards hold the bird's-eye box. Both are expected to be (x_min, y_min, x_max, y_max).
    Blank lines are skipped.

    :param dataset_root: root folder of the training split. Defaults to the module-level
        ``dataset_train_root``, which the ``__main__`` section of this script sets.
    :return: X_train, Y_train as float32 numpy arrays with one row per data line.
    :raises ValueError: if no root is given and ``dataset_train_root`` is not defined.
    """
    if dataset_root is None:
        # Backward compatibility: fall back to the global set when run as a script.
        dataset_root = globals().get('dataset_train_root')
        if dataset_root is None:
            raise ValueError('No dataset_root given and module-level `dataset_train_root` '
                             'is not set.')
    # prepare data
    print('Preparing training data...')
    X_train = []
    Y_train = []
    # Without recursive=True, '**' matches exactly one directory level (like '*').
    for train_file in glob(join(dataset_root, '**', 'filtered_data.txt')):
        for line in read_lines_from_file(train_file):
            line = line.strip()
            if not line:
                continue  # blank lines would otherwise create empty (ragged) rows
            fields = line.split(',')
            X_train.append(fields[6:10])
            Y_train.append(fields[10:])
    X_train = np.array(X_train, dtype=np.float32)
    Y_train = np.array(Y_train, dtype=np.float32)
    return X_train, Y_train
def train_baseline_model(model, batchsize=32, model_filename=None, pixel_space=False):
    """
    Train a baseline model on the GTAV training data.

    :param model: the model we want to train (either homography or grid). It must expose
        ``fit(batch_X, batch_Y)``, plus ``save_weights(filename)`` when ``model_filename``
        is given.
    :param batchsize: number of samples passed to each ``model.fit`` call. For
        ``HomographyModel`` it does not affect the result, because ``fit`` only accumulates.
    :param model_filename: optional file to save the model to at the end of training.
    :param pixel_space: if True, multiply all coordinates by the 1920x1080 frame size before
        fitting, to use the original pixel range. Otherwise use them as stored.
    :return: None
    :raises ValueError: if ``batchsize`` is smaller than 1.
    """
    if batchsize < 1:
        raise ValueError('batchsize must be >= 1, got {}'.format(batchsize))
    X_train, Y_train = get_training_data()
    # The multiplicative factor does not depend on the batch, so compute it once.
    if pixel_space:
        x_pix_max, y_pix_max = 1920, 1080
        mul = np.expand_dims(np.array([x_pix_max, y_pix_max, x_pix_max, y_pix_max],
                                      dtype=np.float32), axis=0)
    else:
        mul = 1.
    print('Starting training.')
    n_samples = X_train.shape[0]
    for start in tqdm(range(0, n_samples, batchsize)):
        stop = min(n_samples, start + batchsize)
        model.fit(X_train[start:stop] * mul, Y_train[start:stop] * mul)
    if model_filename is not None:
        model.save_weights(model_filename)
if __name__ == '__main__':
    # Edit this path before running. get_training_data() reads it as a module-level global.
    dataset_train_root = join('<your_path_here>/GTA_dataset', 'train')
    weights_path = 'pretrained/homography_model.pickle'
    train_baseline_model(HomographyModel(), batchsize=32, model_filename=weights_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment