NUC detector
import argparse
import os
import cv2
import numpy as np
from sklearn import svm
from sklearn.base import BaseEstimator
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
import joblib
import pandas as pd
from tqdm.asyncio import tqdm
from multiprocessing import Pool
def parser():
    parser = argparse.ArgumentParser(description="NUC SVM Detector")
    subparsers = parser.add_subparsers(dest='command_name')

    parser_pre = subparsers.add_parser('train', help="Train the SVM")
    parser_pre.add_argument("--image_list_1", type=str, required=True,
                            help="path to file with a list of NUC (label 1) IR image filepaths, one per line")
    parser_pre.add_argument("--image_list_0", type=str, required=True,
                            help="path to file with a list of non-NUC (label 0) IR image filepaths, one per line")
    parser_pre.add_argument("--output_model_fn", type=str, default="nuc_detector.pkl",
                            help="Filename for the output model binaries.")
    parser_pre.add_argument("--test_size", type=float, default=0.2,
                            help="Fraction of the dataset to hold out from training and use for testing. [0.0, 1.0]")

    parser_plot = subparsers.add_parser('predict', help="Create predictions from an image list once a model is trained.")
    parser_plot.add_argument("--image_list", type=str, required=True,
                            help="path to image list file.")
    parser_plot.add_argument("--model_fn", type=str, default="nuc_detector.pkl",
                            help="The model to use for prediction.")
    parser_plot.add_argument("--output_predictions_fn", type=str, default="predictions.csv",
                            help="Filename to save output predictions to.")
    return parser.parse_args()
#####################################################################
# Model/Preprocessing Pipeline
#####################################################################
def make_batches(iterable, n=1):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]


class PreprocTransformer(BaseEstimator):
    def __init__(self, batch_size=16, n_pools=10):
        self.batch_size = batch_size
        self.n_pools = n_pools
        super().__init__()

    def fit(self, X, y=None, **fit_params):
        return self

    def transform(self, X):
        a = list(zip(X, range(0, len(X))))
        batches = make_batches(a, self.batch_size)
        with Pool(self.n_pools) as p:
            feature_batches = list(tqdm(p.imap(self.load_and_extract_feature, batches), total=len(X) // self.batch_size))
        # reconstruct batches at same indices
        features = np.zeros((len(X), 256))
        idxs = []
        for b in feature_batches:
            features[b[1]] = b[0]
            idxs += b[1]
        # sanity check that we did put features at every index of the originally empty feature map
        assert (set(idxs) == set(range(len(X))))
        return features

    def load_and_extract_feature(self, image_list):
        # The feature for each image is a 256-bin histogram of the min/max-normalized 8-bit image.
        out = []
        idxs = []
        for im_fp, idx in image_list:
            im = cv2.imread(im_fp, -1)
            im_ir = ((im - np.min(im)) / (0.0 + np.max(im) - np.min(im)))
            im_ir = im_ir * 255.0
            im_ir = im_ir.astype(np.uint8)
            hist, b = np.histogram(im_ir, 256)
            out.append(hist.astype(int))
            idxs.append(idx)
        return np.array(out), idxs
#####################################################################
# Helper functions
#####################################################################
def read_image_list_or_dir(list_fp):
    if os.path.isfile(list_fp):
        with open(list_fp, 'r') as f:
            ims = [line.rstrip('\n') for line in f]
        base_dir = os.path.abspath(os.path.dirname(list_fp))
    else:
        ims = os.listdir(list_fp)
        base_dir = os.path.abspath(list_fp)
    ims_out = []
    for im_fp in ims:
        if os.path.isabs(im_fp):
            ims_out.append(im_fp)
        else:
            ims_out.append(os.path.join(base_dir, im_fp))
    return ims_out
#####################################################################
# Commands
#####################################################################
def train(image_list_0, image_list_1, test_size):
    ims_all = image_list_0 + image_list_1
    labels_all = np.concatenate((np.zeros(len(image_list_0)), np.ones(len(image_list_1))), axis=0).astype(int)
    data_train, data_test, labels_train, labels_test = train_test_split(ims_all, labels_all, test_size=test_size,
                                                                        random_state=42)
    clf = svm.SVC()
    sk_pipe = Pipeline([("trans_norm", PreprocTransformer(batch_size=16, n_pools=10)), ("clf", clf)])
    print(f'Training model...')
    sk_pipe.fit(data_train, labels_train)
    print(f'Evaluating on test set...')
    test_pred = sk_pipe.predict(data_test)
    t = accuracy_score(labels_test, test_pred)
    print(f'Accuracy: {t}')
    return sk_pipe


def predict(image_list, model_fn='nuc_detector.pkl'):
    model = joblib.load(model_fn)
    return model.predict(image_list)
#####################################################################
# Program entry point
#####################################################################
def main():
    args = parser()
    command = args.command_name
    if command == 'train':
        image_list_1 = read_image_list_or_dir(args.image_list_1)
        image_list_0 = read_image_list_or_dir(args.image_list_0)
        output_model_fn = args.output_model_fn
        test_size = args.test_size
        model = train(image_list_0, image_list_1, test_size)
        print(f'Saving model to {output_model_fn}')
        joblib.dump(model, output_model_fn)
    elif command == 'predict':
        image_list = read_image_list_or_dir(args.image_list)
        print(f'Predicting for {len(image_list)} images.')
        predictions = predict(image_list, args.model_fn)
        res_dict = []
        for image_fp, prediction in zip(image_list, predictions):
            res_dict.append({'image': image_fp, 'is_nuc': prediction})
        df = pd.DataFrame(res_dict)
        print(f'Saving predictions to {args.output_predictions_fn}')
        df.to_csv(args.output_predictions_fn, index=False)
    else:
        print('Invalid command. Try using the -h or --help flag.')


if __name__ == "__main__":
    main()
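## Usage (a sketch of example invocations based on the argparse definitions above; the script filename
## nuc_svm.py and the image-list filenames are placeholders, not part of the gist)
# Train an SVM on two image lists (NUC examples labeled 1, non-NUC examples labeled 0) and save the fitted pipeline:
#   python nuc_svm.py train --image_list_1 nuc_images.txt --image_list_0 clean_images.txt --output_model_fn nuc_detector.pkl
# Predict on a new image list with a trained model and write a CSV with an is_nuc column per image:
#   python nuc_svm.py predict --image_list images.txt --model_fn nuc_detector.pkl --output_predictions_fn predictions.csv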
## Instructions
# 1. First, generate a text file listing all IR .tif images, one filename per line. It is fine to include files from
#    different cameras and flights, but if a flight_camera is included in the list you should include all images from
#    that flight_camera.
# 2. Preprocess (extract features). Run the script with the preprocess command and the list from step 1,
#    e.g. 'python nuc.py preprocess --image_list images.txt'
#    This will take some time and in the end will save an output.csv in the directory the script was run from.
#    You can override the default output name with the --csv_out flag, e.g. --csv_out custom_name.csv
#
# Once you have this output.csv there are two commands available for getting the NUCs.
# 1. To list the NUCs, use the list command and pass in the csv that step 2 generated,
#    e.g. 'python nuc.py list --csv_in output.csv'
#    Use 'python nuc.py list --help' to see other parameters that can be customized (defaults should be good).
#    This will generate a file nucs.txt in the current directory with the images that are found to be NUCs.
#    To view the NUC detections in a human-viewable format, use the --show flag.
#    (WARNING: --show copies every detection image, so if there are over 10,000 it is probably a bad idea.)
# 2. If there are lots of false positive or false negative NUCs, you can debug easily by generating a plot of Z scores
#    for each flight and camera, then looking at the graph to find a good Z cutoff/threshold (see the worked example below).
#    To visualize the Z scores run 'python nuc.py plot --csv_in output.csv'
#    Once a good Z cutoff is determined, for example 2.0, use --z_thresh 2.0 with the list command,
#    e.g. 'python nuc.py list --csv_in output.csv --z_thresh 2.0'
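#
# A minimal worked example of the Z score thresholding (hypothetical numbers, not from a real flight):
# preprocess groups images by flight_cam and computes z = (mse - mean(mse)) / std(mse) for each image in the group.
# If a flight_cam's raw MSE values have mean 1000 and standard deviation 50, an image with raw MSE 1120 gets
# z = (1120 - 1000) / 50 = 2.4, so it is written to nucs.txt with --z_thresh 2.0 but not with the default of 4.0.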
import copy
import os
from datetime import datetime
from typing import List
import argparse
import cv2
import numpy as np
import pandas as pd
import re
from tqdm.asyncio import tqdm
pd.plotting.register_matplotlib_converters()
flight_pattern = re.compile(r'fl\d\d')
def parser():
    parser = argparse.ArgumentParser(description="NUC Detection")
    subparsers = parser.add_subparsers(dest='command_name')

    parser_pre = subparsers.add_parser('preprocess', help="Preprocess imagery and generate the features csv. Run 'preprocess -h' for more help")
    parser_pre.add_argument("--image_list", type=str, required=True,
                            help="path to file with list of IR image filepaths, one per line")
    parser_pre.add_argument("--csv_out", type=str, default="output.csv",
                            help="path to save the resulting csv file with features and Z scores")
    parser_pre.add_argument("--ksize", type=int, default=5,
                            help="kernel size, e.g. 5 uses a 5x5 Gaussian kernel")
    parser_pre.add_argument("--kstep", type=int, default=1,
                            help="step of the kernel over the image")

    parser_plot = subparsers.add_parser('plot', help="Plot timeseries of Z scores. Run 'plot -h' for more help")
    parser_plot.add_argument("--csv_in", type=str, required=True,
                            help="csv file path for the csv file generated by preprocess")
    parser_plot.add_argument("--z_thresh", type=float, default=4.0,
                            help="Z score threshold")
    parser_plot.add_argument("--out_dir", type=str, default='charts/',
                            help="Directory to save charts to.")

    parser_list = subparsers.add_parser('list', help="Generate a list of outlier images. Run 'list -h' for more help")
    parser_list.add_argument("--csv_in", type=str, required=True,
                            help="csv file path for the csv file generated by preprocess")
    parser_list.add_argument("--z_thresh", type=float, default=4.0,
                            help="Z score threshold")
    parser_list.add_argument("--list_out", type=str, default='nucs.txt',
                            help="Output file to write outlier images to")
    parser_list.add_argument("--show", action='store_true',
                            help="Save the detected NUC images to a folder called nucs_list in a human-viewable format.")
    return parser.parse_args()
#####################################################################
# Helper functions
#####################################################################
# Min/max normalize a 1-channel IR image to an 8-bit image
def _normalize(im):
    im_ir = ((im - np.min(im)) / (0.0 + np.max(im) - np.min(im)))
    im_ir = im_ir * 255.0
    im_ir = im_ir.astype(np.uint8)
    return im_ir


# Given an image, a kernel, and a step size, blur the image twice with a Gaussian kernel and return the
# sum of squared differences over all pixels between the single-blurred and double-blurred images.
def _ssd_mse(im: np.ndarray, kernel, step: int):
    blur_non = cv2.GaussianBlur(im, kernel, step)
    last_blur_non = cv2.GaussianBlur(blur_non, kernel, step)
    ssd_blur_non = np.sum((last_blur_non - blur_non) ** 2)
    return ssd_blur_non


# Given an array, return an array with the Z score of each value in the given array
def _Z_scores(y):
    mu = y.mean()
    std = y.std()
    z = (y - mu) / std
    return z


# Parse an image filename and return the flight/cam string id, e.g. 'fl05_C'
def _fn_key(file_name):
    name_parts = file_name.split('_')
    start_idx = 0
    # find the 'flxx' section
    for i in range(len(name_parts)):
        part = name_parts[i]
        if flight_pattern.match(part):
            start_idx = i
            break
    flight = name_parts[start_idx]
    cam = name_parts[start_idx + 1]
    return '%s_%s' % (flight, cam)


# Parse an image filename and return its capture time as a datetime
def _timestamp(file_name):
    name_parts = file_name.split('_')
    start_idx = 0
    # find the 'flxx' section
    for i in range(len(name_parts)):
        part = name_parts[i]
        if flight_pattern.match(part):
            start_idx = i
            break
    day = name_parts[start_idx + 2]
    time = name_parts[start_idx + 3]
    ms = time.split('.')[1]
    time = time.split('.')[0]
    day_hr_ms = day + '_' + time
    ts = datetime.strptime(day_hr_ms, "%Y%m%d_%H%M%S").timestamp() + float('.' + ms)
    timestamp = datetime.fromtimestamp(ts)
    return timestamp
def _plot_flight(zs, times, flc, z_thresh, out_dir):
    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    os.makedirs(out_dir, exist_ok=True)

    # generate a boolean mask for outliers
    def mask_outside(z_score, thresh):
        outside = -thresh < z_score
        outside2 = z_score < thresh
        mask = ~np.logical_and(outside, outside2)
        return mask

    outside = mask_outside(zs, z_thresh)
    # setup figure
    fig, ax = plt.subplots()
    fig.set_size_inches(35, 8)
    # plot mses over time
    ax.plot(times, zs)
    # plot masks aka regions outside of confidence intervals
    ax.fill_between(times, 0, 1, where=outside,
                    color='green', alpha=0.2, transform=ax.get_xaxis_transform())
    # plot confidence intervals
    ax.fill_between(times, -z_thresh, z_thresh, color='b', alpha=.2)
    plt.title('%s Normalize IR Gaussian Noise MSE over time' % flc)
    plt.xlabel('time')
    plt.ylabel('Z score')
    plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
    ax.xaxis.set_major_locator(mdates.MinuteLocator(interval=10))
    ax.set_ylim([-6, max(zs) + 2])
    ax.set_xlim([times[0], times[-1]])
    # save fig
    fig_fn = flc + '.png'
    plt.savefig(os.path.join(out_dir, fig_fn))
    # plt.show()
#####################################################################
# Commands
#####################################################################
def cmd_preprocess(image_paths: List[str], kernel=(7, 7), step=1, save_fn='output.csv'):
    df = pd.DataFrame(columns=['filepath', 'raw_mse', 'normed_mse'])
    with tqdm(total=len(image_paths)) as pbar:
        for idx, im_fp in enumerate(image_paths):
            img = cv2.imread(im_fp, cv2.IMREAD_UNCHANGED)  # load raw IR image
            if img is None:
                print('Failed to read: %s' % im_fp)
                continue
            raw_value = _ssd_mse(img, kernel=kernel, step=step)  # calculate mse
            normed_value = _ssd_mse(_normalize(img), kernel=kernel, step=step)  # normalize then calculate mse
            df = df.append({'filepath': im_fp, 'raw_mse': raw_value, 'normed_mse': normed_value}, ignore_index=True)
            pbar.update(1)
    df_copy = copy.deepcopy(df)
    df_copy.loc[:, 'flight_cam'] = df_copy.apply(lambda row: _fn_key(os.path.basename(row['filepath'])), axis=1)
    df_copy.to_csv(save_fn, index=False)

    df_feat = pd.read_csv(save_fn)
    gb = df_feat.groupby('flight_cam')
    groups = {x: gb.get_group(x) for x in gb.groups}
    for k in groups:
        df_group = groups[k]
        # calculate group Z scores
        df_group['z_score_raw'] = _Z_scores(df_group['raw_mse'])
        df_group['z_score_normed'] = _Z_scores(df_group['normed_mse'])
    df_out = pd.concat(groups.values(), ignore_index=True)
    df_out.to_csv(save_fn, index=False)


def cmd_plot_charts(load_fn, z_thresh, out_dir):
    df_feat = pd.read_csv(load_fn)
    gb = df_feat.groupby('flight_cam')
    groups = {x: gb.get_group(x) for x in gb.groups}
    for k in groups:
        df_group = groups[k]
        df_copy = copy.deepcopy(df_group)
        df_copy.loc[:, 'ts'] = df_copy.apply(lambda row: _timestamp(os.path.basename(row['filepath'])), axis=1).values
        df_copy = df_copy.sort_values('ts')
        zs = df_copy['z_score_raw'].values
        times = df_copy['ts'].values
        _plot_flight(zs, times, k, z_thresh, out_dir)


def cmd_list_outliers(load_fn, save_fn, thresh, show):
    df = pd.read_csv(load_fn)
    df_outlier = df[df['z_score_raw'] > thresh]
    outlier_fps = list(df_outlier['filepath'])
    with open(save_fn, 'w') as f:
        for fp in outlier_fps:
            f.write('%s\n' % fp)
    if show:
        os.makedirs('nucs_list', exist_ok=True)
        for fp in outlier_fps:
            im = _normalize(cv2.imread(fp, -1))
            fn = os.path.basename(fp).replace('.tif', '.jpg')
            new_fp = os.path.join('nucs_list', fn)
            im_color = cv2.cvtColor(im, cv2.COLOR_GRAY2RGB)
            cv2.imwrite(new_fp, im_color)
        print('Saved %d images to the nucs_list folder' % len(outlier_fps))
    print('Wrote %d outliers to %s' % (len(outlier_fps), save_fn))
#####################################################################
# Program entry point
#####################################################################
def main():
    args = parser()
    command = args.command_name
    if command == 'preprocess':  # extract features
        list_fp = args.image_list
        with open(list_fp, 'r') as f:
            image_list = [line.rstrip('\n') for line in f]
        cmd_preprocess(
            image_paths=image_list,
            kernel=(args.ksize, args.ksize),
            step=args.kstep,
            save_fn=args.csv_out
        )
    elif command == 'plot':  # generate plots
        cmd_plot_charts(args.csv_in, args.z_thresh, args.out_dir)
    elif command == 'list':  # save outliers to a list
        cmd_list_outliers(args.csv_in, args.list_out, args.z_thresh, args.show)
    else:
        print('Invalid command. Try using the -h or --help flag.')


if __name__ == "__main__":
    main()
joblib>=0.17.0
matplotlib>=3.3.2
numpy>=1.19.2
opencv-python>=4.4.0.44
pandas>=1.1.3
pyparsing>=2.4.7
python-dateutil>=2.8.1
scikit-learn>=0.23.2
tqdm>=4.50.2
Example output of the plot command. The y axis is the z-score, and each point represents an image in the timeseries. Highlighted green bands mark images that fall outside the z-score threshold (the --z_thresh argument). As we can see, when the camera has to correct itself, a continuous chunk of images gets NUCed, ending once the camera finishes re-calibrating.

[Image: fl05_L_mse (example Z-score plot for flight_cam fl05_L)]
