Skip to content

Instantly share code, notes, and snippets.

@ndrplz
Last active September 4, 2022 16:33
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ndrplz/ac164425c8e4d19b8a14566f497eef9b to your computer and use it in GitHub Desktop.
Save ndrplz/ac164425c8e4d19b8a14566f497eef9b to your computer and use it in GitHub Desktop.
Define and train the homography baseline for "Learning to Map Vehicles into Bird's Eye View"
"""
HomographyModel maps the bottom of the frontal frame bbox to the birdeye view.
Length and width of the bounding box are evaluated as mean across training.
"""
import os
import cv2
import pickle
import numpy as np
from glob import glob
from os.path import join
from tqdm import tqdm
class HomographyModel:
    """
    Baseline that maps frontal-view boxes to bird's-eye-view boxes with a homography.

    For every training pair, the two bottom corners of the frontal box,
    (x_min, y_max) and (x_max, y_max), are used as source points and the two
    bottom corners of the bird's-eye box as destination points. A single
    homography is estimated (lazily, in ``predict``) from all accumulated pairs.
    The bird's-eye box height is the mean height seen during training; the mean
    width is accumulated and persisted, but ``predict`` takes the x extent from
    the projected corners instead.
    """

    def __init__(self):
        self.H = None          # cached 3x3 homography; None means "re-estimate on next predict"
        self.src_points = []   # frontal-view bottom corners as (x, y) tuples
        self.dst_points = []   # bird's-eye bottom corners as (x, y) tuples
        self.h_sum = 0         # running sum of bird's-eye box heights
        self.w_sum = 0         # running sum of bird's-eye box widths
        self.n_images = 0      # number of box pairs accumulated so far

    def fit(self, frontal_coords, birdeye_coords):
        """
        Accumulate point correspondences and box statistics from a batch.

        :param frontal_coords: frontal coordinates in a (batchsize, 4) array. (x_min, y_min, x_max, y_max)
        :param birdeye_coords: birdeye coordinates in a (batchsize, 4) array. (x_min, y_min, x_max, y_max)
        :return: None
        :raises ValueError: if the two batches have a different number of rows.
        """
        frontal_coords = np.asarray(frontal_coords)
        birdeye_coords = np.asarray(birdeye_coords)
        if len(frontal_coords) != len(birdeye_coords):
            raise ValueError('frontal_coords and birdeye_coords must have the same number of rows '
                             '({} != {})'.format(len(frontal_coords), len(birdeye_coords)))

        for frontal, birdeye in zip(frontal_coords, birdeye_coords):
            xf_min, _, xf_max, yf_max = frontal
            xb_min, yb_min, xb_max, yb_max = birdeye
            # Bottom-left and bottom-right corners form the correspondence pairs.
            self.src_points += [(xf_min, yf_max), (xf_max, yf_max)]
            self.dst_points += [(xb_min, yb_max), (xb_max, yb_max)]
            self.h_sum += (yb_max - yb_min)
            self.w_sum += (xb_max - xb_min)
            self.n_images += 1

        # New correspondences invalidate any cached homography.
        self.H = None

    def predict(self, coords):
        """
        Map a batch of frontal boxes to bird's-eye boxes.

        The bottom corners of each frontal box are projected with the homography.
        The projected x values give x_min / x_max, the mean of the projected y
        values gives y_max, and y_min is y_max minus the mean training height.

        :param coords: batch of frontal coordinates as a numpy array (batchsize, 4). (x_min, y_min, x_max, y_max)
        :return: pred: batch of birdeye coordinates as a numpy array (batchsize, 4). (x_min, y_min, x_max, y_max)
            ``pred`` has the same dtype as ``coords``, so integer inputs yield truncated values.
        :raises RuntimeError: if the model holds no training data, or the homography cannot be estimated.
        """
        if self.n_images == 0:
            raise RuntimeError('HomographyModel has no training data: call fit() or load_weights() first.')
        if self.H is None:
            print('Re-estimating homography...')
            self.H, _ = cv2.findHomography(np.array(self.src_points), np.array(self.dst_points))
            if self.H is None:
                raise RuntimeError('cv2.findHomography could not estimate a homography '
                                   'from the accumulated correspondences.')

        coords = np.asarray(coords)
        pred = np.zeros_like(coords)
        # Mean bird's-eye height is the same for every row, so compute it once.
        h = self.h_sum / self.n_images

        for b, (x_min, _, x_max, y_max) in enumerate(coords):
            # Homogeneous bottom-left / bottom-right corners, one per column.
            bottom = np.array([(x_min, y_max, 1), (x_max, y_max, 1)]).T
            bbc = np.dot(self.H, bottom)
            bbc = bbc / bbc[2]  # out-of-place so an integer H cannot break the division
            y_bottom = (bbc[1, 0] + bbc[1, 1]) / 2
            pred[b] = np.array([bbc[0, 0], y_bottom - h, bbc[0, 1], y_bottom], dtype=pred.dtype)
        return pred

    def save_weights(self, filename):
        """
        Save the model state to a pickle file, creating parent directories if needed.

        :param filename: the name of the file to store the model in.
        :return: None
        """
        directory = os.path.dirname(filename)
        # dirname() is '' for a bare filename; os.makedirs('') would raise.
        if directory:
            os.makedirs(directory, exist_ok=True)
        state = {
            'H': self.H,
            'src_points': self.src_points,
            'dst_points': self.dst_points,
            'h_sum': self.h_sum,
            'w_sum': self.w_sum,
            'n_images': self.n_images,
        }
        with open(filename, 'wb') as handle:
            pickle.dump(state, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def load_weights(self, filename):
        """
        Restore the model state from a pickle file written by ``save_weights``.

        Security note: unpickling can execute arbitrary code, so only load
        files from a trusted source.

        :param filename: the name of the file to load the model from.
        :return: None
        """
        with open(filename, 'rb') as handle:
            state = pickle.load(handle)
        self.H = state['H']
        self.src_points = state['src_points']
        self.dst_points = state['dst_points']
        self.h_sum = state['h_sum']
        self.w_sum = state['w_sum']
        self.n_images = state['n_images']
def read_lines_from_file(filename):
    """
    Return all lines of a text file, each keeping its trailing newline.

    :param filename: path of the text file to read.
    :return: list of strings, one per line (empty list for an empty file).
    """
    with open(filename, mode='r') as text_file:
        # Iterating a text file yields exactly what readlines() would return.
        return list(text_file)
def get_training_data(dataset_root=None):
    """
    Helper function to get training data.

    Reads every ``filtered_data.txt`` found one directory below the dataset
    root. Each line is split on commas: fields 6-9 form the frontal box and
    fields 10 onward form the bird's-eye box.

    :param dataset_root: training-split directory. Defaults to the module-level
        ``dataset_train_root`` (set in the ``__main__`` section) for backward
        compatibility.
    :return: X_train, Y_train as float32 numpy arrays.
    """
    if dataset_root is None:
        dataset_root = dataset_train_root

    # prepare data
    print('Preparing training data...')
    X_train = []
    Y_train = []
    # Without recursive=True, '**' matches exactly one directory level.
    # sorted() makes the sample order independent of filesystem listing order.
    for train_file in sorted(glob(join(dataset_root, '**', 'filtered_data.txt'))):
        for line in read_lines_from_file(train_file):
            if not line.strip():
                continue  # blank lines (e.g. at EOF) would otherwise create ragged rows
            fields = line.split(',')
            X_train.append([float(value) for value in fields[6:10]])
            Y_train.append([float(value) for value in fields[10:]])
    X_train = np.array(X_train, dtype=np.float32)
    Y_train = np.array(Y_train, dtype=np.float32)
    return X_train, Y_train
def train_baseline_model(model, batchsize=32, model_filename=None, pixel_space=False, frame_size=(1920, 1080)):
    """
    Function to train a baseline with GTAV data.

    :param model: the model we want to train (either homography or grid); must provide
        ``fit(batch_X, batch_Y)`` and ``save_weights(filename)``.
    :param batchsize: number of samples passed to ``model.fit`` per call. For the
        homography model it does not affect the result.
    :param model_filename: optional file to save the model to at the end of training.
    :param pixel_space: if True, scale every coordinate by the frame size; by default
        the loaded coordinates are passed through unscaled.
    :param frame_size: (width, height) in pixels used when ``pixel_space`` is True.
    :return: None
    :raises ValueError: if ``batchsize`` is not a positive integer.
    """
    if batchsize < 1:
        raise ValueError('batchsize must be >= 1, got {}'.format(batchsize))

    X_train, Y_train = get_training_data()

    # Multiplicative factor that optionally maps coordinates to the original pixel
    # range. It does not depend on the batch, so it is computed once.
    if pixel_space:
        x_pix_max, y_pix_max = frame_size
        mul = np.array([[x_pix_max, y_pix_max, x_pix_max, y_pix_max]], dtype=np.float32)
    else:
        mul = 1.

    print('Starting training.')
    n_samples = X_train.shape[0]
    for start in tqdm(range(0, n_samples, batchsize)):
        stop = min(n_samples, start + batchsize)
        model.fit(X_train[start:stop] * mul, Y_train[start:stop] * mul)

    if model_filename is not None:
        model.save_weights(model_filename)
if __name__ == '__main__':
    # Module-level global consumed by get_training_data(); point it at the
    # training split of your local GTA dataset copy.
    dataset_train_root = join('<your_path_here>/GTA_dataset', 'train')
    train_baseline_model(HomographyModel(),
                         batchsize=32,
                         model_filename='pretrained/homography_model.pickle')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment