Skip to content

Instantly share code, notes, and snippets.

@readicculus
Created July 21, 2021 20:17
Show Gist options
  • Save readicculus/dc46514eb997ef8ea286f813621b65e3 to your computer and use it in GitHub Desktop.
Save readicculus/dc46514eb997ef8ea286f813621b65e3 to your computer and use it in GitHub Desktop.
## Instructions
# 1. first generate a list in a text file of all IR .tif images, one filename per line, it is ok to include files from different
# cameras and flights but if a flight_camera is included in the list you should include all images from that flight_camera.
# 2. Preprocess(extract features). Run the script with the preprocess target and list from step 1
# ex 'python nuc.py preprocess --image_list images.txt'
# this will take some time and in the end will save an output.csv in the same directory that the script was run from
# you can override the output default name by specifying the flag --csv_out custom_name.csv
#
# Once you have this output.csv there are two commands available for getting the NUCs.
# 1. To list the NUCs use the list command and pass in the csv that step 2 generated.
# ex. 'python nuc.py list --csv_in output.csv'
# use 'python nuc.py list --help' to see other parameters that can be customized(defaults should be good)
# this will generate a file nucs.txt in the current directory with the images that are found to be NUCs.
# To view the NUC detections in a human-viewable format use the --show flag.
# (WARNING --show will copy all detection images so if there are over 10,000 probably a bad idea)
# 2. If there are lots of false-positive or false-negative NUCs, you can debug easily by generating a plot of Z scores
# for each flight and camera, then looking at this graph to find a good Z cutoff/threshold.
# To visualize the Z scores run 'python nuc.py plot --csv_in output.csv'
# Then once a good Z cutoff is determined, for example 2.0, pass --z_thresh 2.0 to the list command.
# ex. 'python nuc.py list --csv_in output.csv --z_thresh 2.0'
import copy
import os
from datetime import datetime
from typing import List
import argparse
import cv2
import numpy as np
import pandas as pd
import re
from tqdm.asyncio import tqdm
# Register pandas' date/time converters with matplotlib at import time.
# NOTE(review): presumably needed so the datetime x-axis drawn by the plot
# command renders correctly - confirm before removing.
pd.plotting.register_matplotlib_converters()
# Matches the flight token of an image file name: 'fl' followed by two digits
# (e.g. 'fl05'). Written as a raw string so '\d' reaches the regex engine
# instead of being treated as an (invalid) string escape.
flight_pattern = re.compile(r'fl\d\d')
def parser(argv=None):
    """Build the NUC-detection command line and parse it.

    Three sub-commands are defined, ``preprocess``, ``plot`` and ``list``;
    the selected one is stored in ``command_name`` (``None`` when no
    sub-command is given).

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    root = argparse.ArgumentParser(description="NUC Detection")
    subparsers = root.add_subparsers(dest='command_name')

    # preprocess: extract the per-image features into a csv
    parser_pre = subparsers.add_parser(
        'preprocess',
        help="Preprocess imagery and generate the features csv. run 'preprocess -h' for more help")
    parser_pre.add_argument("--image_list", type=str, required=True,
                            help="path to file with list of image filepaths of ir images, one per line")
    parser_pre.add_argument("--csv_out", type=str, default="output.csv",
                            help="path to save the resulting csv file with features and Z scores")
    parser_pre.add_argument("--ksize", type=int, default=5,
                            help="kernel size, if given 5 will use gaussian kernel 5x5 (must be a positive odd number)")
    # The value is handed to cv2.GaussianBlur as its sigma argument, so the
    # help text describes it as such (the flag name is kept for compatibility).
    parser_pre.add_argument("--kstep", type=int, default=1,
                            help="sigma of the gaussian kernel (passed to cv2.GaussianBlur as sigmaX)")

    # plot: draw the Z-score time series of every flight/camera
    parser_plot = subparsers.add_parser(
        'plot', help="Plot timeseries of Z scores. run 'plot -h' for more help")
    parser_plot.add_argument("--csv_in", type=str, required=True,
                             help="csv file path for csv file generated by preprocess")
    parser_plot.add_argument("--z_thresh", type=float, default=4.0,
                             help="Z score threshold")
    parser_plot.add_argument("--out_dir", type=str, default='charts/',
                             help="Directory to save charts to.")

    # list: write the outlier image paths to a text file
    parser_list = subparsers.add_parser(
        'list', help="Generate a list of outlier images. run 'list -h' for more help")
    parser_list.add_argument("--csv_in", type=str, required=True,
                             help="csv file path for csv file generated by preprocess")
    parser_list.add_argument("--z_thresh", type=float, default=4.0,
                             help="Z score threshold")
    parser_list.add_argument("--list_out", type=str, default='nucs.txt',
                             help="Output file to write outlier images to")
    # The list command writes into 'nucs_list'; the help text names that folder.
    parser_list.add_argument("--show", action='store_true',
                             help="Save the detected NUC images to a folder called nucs_list in a human viewable format.")

    return root.parse_args(argv)
#####################################################################
# Helper functions
#####################################################################
# Min/max normalize a 1 channel IR image
def _normalize(im):
im_ir = ((im - np.min(im)) / (0.0 + np.max(im) - np.min(im)))
im_ir = im_ir*255.0
im_ir = im_ir.astype(np.uint8)
return im_ir
# Given an image, a kernel size, and a sigma (`step`), Gaussian-blur the image twice and
# return the sum of squared differences (SSD) between the once- and twice-blurred images.
def _ssd_mse(im: np.array, kernel, step: int):
    """Return the sum of squared differences between two successive blurs.

    The image is Gaussian-blurred once, the result is blurred again with the
    same parameters, and the squared per-pixel differences between the two
    blurred images are summed.

    Args:
        im: image array as accepted by ``cv2.GaussianBlur``.
        kernel: ``(width, height)`` kernel size passed to ``cv2.GaussianBlur``.
        step: passed to ``cv2.GaussianBlur`` as its third positional argument
            (sigmaX).

    Returns:
        The sum of squared differences as a NumPy float scalar.
    """
    blur_once = cv2.GaussianBlur(im, kernel, step)
    blur_twice = cv2.GaussianBlur(blur_once, kernel, step)
    # Subtract in float64: the blurred images keep the input dtype, and with
    # unsigned integers (uint8/uint16) a negative difference or a large
    # square would silently wrap around.
    diff = blur_twice.astype(np.float64) - blur_once.astype(np.float64)
    return np.sum(diff ** 2)
# Given an array return an array with Z scores for each value in the given array
def _Z_scores(y):
mu = y.mean()
std = y.std()
z = (y-mu)/std
return z
# Parse an image filename and return the flight/cam string id, e.g. 'fl05_C'
def _fn_key(file_name):
    """Return the '<flight>_<camera>' id of an image file name.

    The name is split on '_'; the first piece matching ``flight_pattern`` is
    the flight and the piece right after it is the camera. When no piece
    matches, the first two pieces are used.
    """
    pieces = file_name.split('_')
    # position of the 'flxx' piece, falling back to 0 when there is none
    flight_at = next(
        (pos for pos, piece in enumerate(pieces) if flight_pattern.match(piece)),
        0,
    )
    return '%s_%s' % (pieces[flight_at], pieces[flight_at + 1])
def _timestamp(file_name):
    """Parse the acquisition time out of an image file name.

    The name is split on '_'. Counting from the first piece that matches
    ``flight_pattern`` (or from the first piece when none matches), the piece
    two positions later is read as the day (``%Y%m%d``) and the piece three
    positions later as the time (``%H%M%S``), optionally followed by '.' and
    fractional-second digits.

    Args:
        file_name: base name of the image file.

    Returns:
        A naive ``datetime`` with the fractional seconds added.

    Raises:
        IndexError: if the name has too few '_'-separated pieces.
        ValueError: if the day/time pieces do not match the expected format.
    """
    # Local import: the module only imports `datetime` from the datetime package.
    from datetime import timedelta

    name_parts = file_name.split('_')
    start_idx = 0
    # find the 'flxx' section
    for i, part in enumerate(name_parts):
        if flight_pattern.match(part):
            start_idx = i
            break

    day = name_parts[start_idx + 2]
    time_pieces = name_parts[start_idx + 3].split('.')
    whole_seconds = time_pieces[0]
    # A name without a fractional part is treated as exactly on the second.
    fraction = time_pieces[1] if len(time_pieces) > 1 else '0'

    parsed = datetime.strptime(day + '_' + whole_seconds, "%Y%m%d_%H%M%S")
    # Add the fraction directly rather than round-tripping through a POSIX
    # timestamp, which would apply the local time zone's DST rules and lose
    # precision in the float addition.
    return parsed + timedelta(seconds=float('.' + fraction))
def _plot_flight(zs, times, flc, z_thresh, out_dir):
    """Plot one Z-score time series and save it as '<flc>.png' in ``out_dir``.

    The chart shows the scores over time, a blue band between ``-z_thresh``
    and ``z_thresh``, and green full-height shading wherever a score falls
    outside that band.

    Args:
        zs: NumPy array of Z scores.
        times: array of time values aligned with ``zs``, in plotting order.
        flc: label used in the title and as the output file stem.
        z_thresh: half-width of the band drawn around zero.
        out_dir: directory the PNG is written to (created if missing).
    """
    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates

    os.makedirs(out_dir, exist_ok=True)
    if len(times) == 0:
        # Nothing to draw; the axis limits below index times[0] / times[-1].
        return

    # boolean mask of scores outside the open interval (-z_thresh, z_thresh)
    outside = ~np.logical_and(-z_thresh < zs, zs < z_thresh)

    # setup figure
    fig, ax = plt.subplots()
    try:
        fig.set_size_inches(35, 8)
        # plot scores over time
        ax.plot(times, zs)
        # shade the regions outside of the threshold band over the full height
        ax.fill_between(times, 0, 1, where=outside,
                        color='green', alpha=0.2, transform=ax.get_xaxis_transform())
        # plot the threshold band
        ax.fill_between(times, -z_thresh, z_thresh, color='b', alpha=.2)
        ax.set_title('%s Normalize IR Gaussian Noise MSE over time' % flc)
        ax.set_xlabel('time')
        ax.set_ylabel('Z score')
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
        ax.xaxis.set_major_locator(mdates.MinuteLocator(interval=10))
        ax.set_ylim([-6, max(zs) + 2])
        ax.set_xlim([times[0], times[-1]])
        # save fig
        fig.savefig(os.path.join(out_dir, flc + '.png'))
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # one figure would accumulate per flight/camera plotted.
        plt.close(fig)
#####################################################################
# Commands
#####################################################################
def cmd_preprocess(image_paths: List[str], kernel=(7,7), step=1, save_fn = 'output.csv'):
    """Extract the blur features of every image and write them, with Z scores, to a csv.

    For each readable image two features are computed with ``_ssd_mse``: one
    on the raw image and one on the ``_normalize``-d image. Rows are then
    grouped by flight/camera id (``_fn_key`` of the base name) and both
    features are converted to per-group Z scores.

    Work is checkpointed to 'progress.csv' in the current directory; if that
    file exists when the command starts, its rows for images in
    ``image_paths`` are reused and only the remaining images are processed.
    NOTE(review): reused rows are not checked against ``kernel``/``step`` -
    delete 'progress.csv' when changing those.

    Args:
        image_paths: image file paths to process.
        kernel: ``(width, height)`` Gaussian kernel size.
        step: value forwarded to ``_ssd_mse`` as ``step``.
        save_fn: path of the csv to write. Columns: filepath, raw_mse,
            normed_mse, flight_cam, z_score_raw, z_score_normed.
    """
    feature_cols = ['filepath', 'raw_mse', 'normed_mse']
    progress_fn = 'progress.csv'
    checkpoint_every = 50000

    # Resume from a previous run: keep its rows instead of discarding them,
    # and tolerate the file not existing on a first run.
    results = {}
    if os.path.exists(progress_fn):
        wanted = set(image_paths)
        progress_df = pd.read_csv(progress_fn)
        for row in progress_df[feature_cols].to_dict('records'):
            if row['filepath'] in wanted:
                results[row['filepath']] = row

    # dict.fromkeys de-duplicates while keeping the caller's order
    image_paths_todo = [fp for fp in dict.fromkeys(image_paths) if fp not in results]
    print(f'Progress left: {len(image_paths_todo)}')

    with tqdm(total=len(image_paths_todo)) as pbar:
        for idx, im_fp in enumerate(image_paths_todo, start=1):
            img = cv2.imread(im_fp, cv2.IMREAD_UNCHANGED)  # load raw IR Image
            if img is None:
                print('Failed to read: %s' % im_fp)
            else:
                raw_value = _ssd_mse(img, kernel=kernel, step=step)
                normed_value = _ssd_mse(_normalize(img), kernel=kernel, step=step)
                results[im_fp] = {'filepath': im_fp, 'raw_mse': raw_value, 'normed_mse': normed_value}
            if idx % checkpoint_every == 0:
                # One row per image: build the frame from the row dicts, not
                # from the path-keyed mapping (whose keys would become columns).
                pd.DataFrame(list(results.values()), columns=feature_cols).to_csv(progress_fn, index=False)
            # advance the bar for unreadable images too so it can reach 100%
            pbar.update(1)

    df = pd.DataFrame(list(results.values()), columns=feature_cols)
    df['flight_cam'] = [_fn_key(os.path.basename(fp)) for fp in df['filepath']]

    scored = []
    for _, df_group in df.groupby('flight_cam'):
        # work on a copy so the new columns are not written to a slice of df
        df_group = df_group.copy()
        # calculate group Z_scores
        df_group['z_score_raw'] = _Z_scores(df_group['raw_mse'])
        df_group['z_score_normed'] = _Z_scores(df_group['normed_mse'])
        scored.append(df_group)

    if scored:
        df_out = pd.concat(scored, ignore_index=True)
    else:
        # no readable images: still emit a csv with the full header
        df_out = df.assign(z_score_raw=pd.Series(dtype=float),
                           z_score_normed=pd.Series(dtype=float))
    df_out.to_csv(save_fn, index=False)
def cmd_plot_charts(load_fn, z_thresh, out_dir):
    """Draw one Z-score chart per flight/camera found in the features csv.

    The csv at ``load_fn`` is grouped by its 'flight_cam' column. Each group
    gets a 'ts' column parsed from the file names with ``_timestamp``, is
    sorted by it, and its 'z_score_raw' values are handed to ``_plot_flight``
    together with ``z_thresh`` and ``out_dir``.
    """
    features = pd.read_csv(load_fn)
    for flight_cam, rows in features.groupby('flight_cam'):
        # copy so the helper column is not added to a slice of the csv frame
        ordered = rows.copy()
        ordered['ts'] = ordered['filepath'].map(
            lambda fp: _timestamp(os.path.basename(fp))
        ).values
        ordered = ordered.sort_values('ts')
        _plot_flight(
            ordered['z_score_raw'].values,
            ordered['ts'].values,
            flight_cam,
            z_thresh,
            out_dir,
        )
def cmd_list_outliers(load_fn, save_fn, thresh, show):
    """Write the paths of images whose raw Z score exceeds ``thresh``.

    Args:
        load_fn: features csv; must contain 'z_score_raw' and 'filepath'.
        save_fn: text file to write, one outlier path per line.
        thresh: rows with ``z_score_raw > thresh`` are reported.
        show: when true, additionally save a normalized '.jpg' copy of every
            readable outlier image into the 'nucs_list' folder.
    """
    df = pd.read_csv(load_fn)
    outlier_fps = df.loc[df['z_score_raw'] > thresh, 'filepath'].tolist()

    with open(save_fn, 'w') as f:
        f.writelines('%s\n' % fp for fp in outlier_fps)

    if show:
        os.makedirs('nucs_list', exist_ok=True)
        saved = 0
        for fp in outlier_fps:
            raw = cv2.imread(fp, cv2.IMREAD_UNCHANGED)
            if raw is None:
                # imread signals failure by returning None; skip rather than
                # crash while normalizing it.
                print('Failed to read: %s' % fp)
                continue
            # swap only the extension; str.replace could also rewrite a
            # '.tif' occurring elsewhere in the name
            stem, _ = os.path.splitext(os.path.basename(fp))
            new_fp = os.path.join('nucs_list', stem + '.jpg')
            im_color = cv2.cvtColor(_normalize(raw), cv2.COLOR_GRAY2RGB)
            cv2.imwrite(new_fp, im_color)
            saved += 1
        print('Saved %d images to the nucs_list folder' % saved)

    print('Wrote %d outliers to %s' % (len(outlier_fps), save_fn))
#####################################################################
# Program entry point
#####################################################################
def main():
    """Parse the command line and run the selected sub-command."""
    args = parser()
    command = args.command_name

    if command == 'preprocess':
        # one image path per line of the list file
        with open(args.image_list, 'r') as list_file:
            image_list = [entry.rstrip('\n') for entry in list_file]
        cmd_preprocess(
            image_paths=image_list,
            kernel=(args.ksize, args.ksize),
            step=args.kstep,
            save_fn=args.csv_out,
        )
        return

    if command == 'plot':
        cmd_plot_charts(args.csv_in, args.z_thresh, args.out_dir)
        return

    if command == 'list':
        cmd_list_outliers(args.csv_in, args.list_out, args.z_thresh, args.show)
        return

    # no (or an unknown) sub-command was given
    print('Invalid command try using the -h or --help flag.')


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment