Skip to content

Instantly share code, notes, and snippets.

@andres-fr
Created January 15, 2020 18:03
Show Gist options
  • Save andres-fr/25f63f33e155cd10d39a636740050308 to your computer and use it in GitHub Desktop.
Save andres-fr/25f63f33e155cd10d39a636740050308 to your computer and use it in GitHub Desktop.
A Python3 script to mask arrays by customizable criteria (e.g. color spheres)
# -*- coding:utf-8 -*-
"""
This script deals with extracting value-based masks from images in various ways.
At some point it might be sensible to use OpenCV to scale further up.
https://www.learnopencv.com/tag/inrange/
Explanation of PIL HSV space:
https://pillow.readthedocs.io/en/latest/reference/ImageColor.html?highlight=hsv#color-names
CLI usage examples (jan 2020)::
python colors_to_binary_mask.py -C "255 0 0" -I ./*.bmp
python colors_to_binary_mask.py -C "255 0 0" -I ./strymonas_0000181.bmp -P -A rgb_range_pick -r 40
python colors_to_binary_mask.py -C "255 0 0" -I ./bad_masks/*.bmp -A rgb_range_pick -r 40 -O ./bad_masks
"""
import argparse
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
# ##############################################################################
# ## HELPERS
# ##############################################################################
def plot_2_rgb_arrays(arr1, arr2, share_zoom=True):
    """
    Plot ``arr1`` on the left and ``arr2`` on the right, side by side.

    Blocks until the user presses Enter, then closes the figure.

    :param arr1: Array accepted by ``matplotlib`` ``imshow`` (left panel).
    :param arr2: Array accepted by ``matplotlib`` ``imshow`` (right panel).
    :param share_zoom: If true, zooming left will correspondingly zoom right.
    """
    fig, axarr = plt.subplots(1, 2, sharex=share_zoom, sharey=share_zoom)
    axarr[0].imshow(arr1)
    axarr[1].imshow(arr2)
    fig.tight_layout()
    fig.show()
    # ``input`` only returns once Enter is pressed, not on any key.
    input("\n\nPlotting. Press Enter to continue...")
    # Close (not merely clear) the figure: ``plt.clf`` left every figure
    # created by ``plt.subplots`` open, so plotting many images accumulated
    # figures.
    plt.close(fig)
def parse_color_str(colors_str):
    """
    Translate a color string from the argparser into a numpy array.

    Expects whitespace-separated 8-bit unsigned integers, e.g. ``'255 0 0'``.
    Commas and square brackets are also tolerated, so ``'[255, 0, 0]'``
    parses to the same result. Repeated, leading or trailing whitespace is
    ignored.

    :param colors_str: The string to parse.
    :returns: A 1D ``np.uint8`` array with one entry per parsed integer
      (3 entries for an RGB color).
    :raises ValueError: If the string contains no numbers, a token that is
      not an integer, or an integer outside ``[0, 255]``.
    """
    # Treat brackets and commas as separators, then split on any whitespace.
    # ``split()`` without an argument drops empty tokens, unlike ``split(" ")``.
    tokens = colors_str.translate(str.maketrans("[],", "   ")).split()
    if not tokens:
        raise ValueError("No color values found in {!r}".format(colors_str))
    values = [int(tok) for tok in tokens]
    # Check the range explicitly rather than relying on how numpy's uint8
    # cast treats out-of-range values.
    out_of_range = [v for v in values if not 0 <= v <= 255]
    if out_of_range:
        raise ValueError(
            "Color values must be in [0, 255], got {}".format(out_of_range))
    return np.array(values, dtype=np.uint8)
# ##############################################################################
# ## ARR TO BINARY MASK
# ##############################################################################
def mask_by_value_set(arr, values):
    """
    Mask the pixels of ``arr`` that exactly equal any of the given values.

    :param arr: A np.uint8 array of shape ``(h, w, c)``. If you are feeding
      black and white images, add one channel via ``arr[:, :, np.newaxis]``
    :param values: A list in the form ``[l1, l2, ...]`` where each l has length
      ``c`` (numpy arrays or plain sequences of ints).
    :returns: A boolean (``np.bool_``) mask of shape ``(h, w)`` with ``True``
      for any pixel in ``arr`` whose full channel vector equals one of the
      given values. An empty ``values`` yields an all-``False`` mask.

    For ``c=3`` this can be used, e.g., to pick RGB colors.
    """
    # sanity check
    assert arr.dtype == np.uint8, \
        "dtype np.uint8 expected and was {}".format(arr.dtype)
    assert len(arr.shape) == 3, \
        "Expected (h, w, c), shape was {}".format(arr.shape)
    h, w, c = arr.shape
    assert all([len(x) == c for x in values]), \
        "Values must all have same length as num_channels(arr)!"
    # Compare whole pixel vectors per value and OR the results. Filtering each
    # channel independently against per-channel value sets would also accept
    # mixed combinations (e.g. picking red and blue would also match black
    # and magenta).
    mask = np.zeros((h, w), dtype=np.bool_)
    for v in values:
        mask |= np.all(arr == np.asarray(v), axis=2)
    return mask
def mask_by_value_ranges(arr, values, eucl_range=0):
    """
    Mask the pixels of ``arr`` lying within a Euclidean distance of any value.

    :param arr: A np.uint8 array of shape ``(h, w, c)``. If you are feeding
      black and white images, add one channel via ``arr[:, :, np.newaxis]``
    :param values: A list in the form ``[l1, l2, ...]`` where each l has length
      ``c`` (numpy arrays or plain sequences of numbers).
    :param eucl_range: A non-negative float, determining the maximal distance
      between ``value`` and any accepted pixel.
    :returns: A boolean (``np.bool_``) mask of shape ``(h, w)`` with ``True``
      for any pixel in ``arr`` that matched any value in the given range.
      An empty ``values`` yields an all-``False`` mask.

    For any value ``v`` in values, and a range ``r``, every pixel ``w`` that
    satisfies ``l2_norm(v - w) <= r`` will be set as ``True``, false otherwise.
    E.g. if I give a blue and a red as values, and a range of 50, any color
    close to red or to blue will be set as True.
    """
    # sanity check
    assert arr.dtype == np.uint8, \
        "dtype np.uint8 expected and was {}".format(arr.dtype)
    assert len(arr.shape) == 3, \
        "Expected (h, w, c), shape was {}".format(arr.shape)
    h, w, c = arr.shape
    assert all([len(x) == c for x in values]), \
        "Values must all have same length as num_channels(arr)!"
    # Convert once, outside the loop. Working in float also avoids uint8
    # wrap-around when subtracting.
    arr_f = arr.astype(np.float32)
    mask = np.zeros((h, w), dtype=np.bool_)
    for v in values:
        # ``np.asarray`` also accepts the plain lists the docstring allows.
        dist = np.linalg.norm(arr_f - np.asarray(v, dtype=np.float32), axis=2)
        mask |= (dist <= eucl_range)  # add (logical_or) to result
    return mask
# ##############################################################################
# ## APP-LEVEL OPS ON MASK
# ##############################################################################
def make_color_mask(mask, true_color, false_color):
    """
    Paint a boolean mask as an RGB image.

    :param mask: A boolean numpy array of shape ``(h, w)``
    :param true_color: Length-3 np.uint8 array in RGB format, used where
      ``mask`` is ``True``.
    :param false_color: Length-3 np.uint8 array in RGB format, used elsewhere.
    :returns: np.uint8 array of shape ``(h, w, 3)``, colored RGB version of the
      given mask.
    """
    # sanity check. ``np.bool_`` works on every numpy version, whereas the
    # ``np.bool`` alias was removed in NumPy 1.24 (and only re-added in 2.0).
    assert mask.dtype == np.bool_, \
        "dtype np.bool expected and was {}".format(mask.dtype)
    assert len(mask.shape) == 2, \
        "Expected binary 2D mask, shape was {}".format(mask.shape)
    h, w = mask.shape
    result = np.empty((h, w, 3), dtype=np.uint8)
    result[:] = false_color
    # Boolean indexing selects the (h, w) positions; each gets a full RGB value.
    result[mask] = true_color
    return result
def colorize_and_save_mask(mask, true_color, false_color, plot=False,
                           out_path="", image=None):
    """
    Color a binary mask, then optionally plot and/or save it.

    :param mask: Boolean 2D mask, forwarded to ``make_color_mask``.
    :param true_color: RGB color for ``True`` pixels.
    :param false_color: RGB color for ``False`` pixels.
    :param plot: If true, plot ``image`` (left) next to the colored mask
      (right) and wait for the user.
    :param out_path: If non-empty, save the colored mask there via PIL.
      Saving is best-effort: failures are printed, not raised.
    :param image: Array shown on the left when plotting. If ``None``, the
      module-level ``arr`` (set by this script's main loop) is used, and if
      that does not exist either, the binary mask itself is shown.
    """
    color_mask = make_color_mask(mask, true_color, false_color)
    # (optional) plotting
    if plot:
        if image is None:
            # Previously this read the global ``arr`` unconditionally, which
            # raised NameError whenever the function was used outside the
            # script's main loop. Keep that lookup only as a fallback.
            image = globals().get("arr", mask)
        plot_2_rgb_arrays(image, color_mask)
    # (optional) saving
    if out_path:
        try:
            Image.fromarray(color_mask).save(out_path)
        except Exception as e:  # best-effort: report and keep processing
            print("Saving went wrong! path: {}, exception: {}".format(
                out_path, e))
        else:
            print("Saved png to {}!".format(out_path))
# ##############################################################################
# ## MAIN ROUTINE
# ##############################################################################
if __name__ == "__main__":
    # arg parser
    parser = argparse.ArgumentParser(description="BMP->Mask")
    parser.add_argument("-A", "--action", default="rgb_set_pick", type=str,
                        choices=["rgb_set_pick", "rgb_range_pick"],
                        help="The criterion to elaborate the mask.")
    parser.add_argument("-I", "--image_paths", required=True,
                        nargs="+", type=str,
                        help="Absolute paths for input images")
    parser.add_argument("-O", "--out_dir", default="", type=str,
                        help="Output directory. Names will append _mask.png")
    parser.add_argument("-C", "--rgb_colors",
                        nargs="+", type=str,
                        help="Colors to look for. E.g. '255 0 0' '1 2 3'")
    parser.add_argument("-r", "--color_range", default=0, type=float,
                        help="Max accepted distance (Eucl.) for rgb_range_pick")
    parser.add_argument("-T", "--true_color", type=str, default="255 255 255",
                        help="Color for true pixels, e.g. '255 255 255'")
    parser.add_argument("-F", "--false_color", type=str, default="0 0 0",
                        help="Color non-true pixels, e.g. '0 0 0'")
    parser.add_argument("-P", "--plot", action="store_true",
                        help="if given, plot img and mask")
    args = parser.parse_args()
    # globals
    ACTION = args.action
    IMG_PATHS = args.image_paths
    OUT_DIR = args.out_dir
    OUT_APPEND = "_mask.png"
    COLOR_RANGE = args.color_range
    PLOT = args.plot
    # Parse all color strings up front, so malformed input becomes a clean CLI
    # error instead of a traceback. ``-C`` is optional for argparse, so
    # ``args.rgb_colors`` may be None.
    try:
        RGB_COLORS = [parse_color_str(s) for s in (args.rgb_colors or [])]
        TRUE_COLOR = parse_color_str(args.true_color)
        FALSE_COLOR = parse_color_str(args.false_color)
    except ValueError as e:
        parser.error(str(e))
    # Both actions need at least one color. Check once, before touching any
    # image, and via parser.error rather than assert (stripped under -O).
    if not RGB_COLORS:
        parser.error("This action requires to give RGB colors. Check -h")
    # Make sure the output directory exists, so saving doesn't fail per image.
    if OUT_DIR:
        os.makedirs(OUT_DIR, exist_ok=True)
    # main loop. NOTE: ``arr`` deliberately stays a module-level name, since
    # colorize_and_save_mask looks it up when plotting.
    for img_path in IMG_PATHS:
        print("Processing", img_path)
        # load image into RGB uint8 numpy array; ``with`` closes the file
        with Image.open(img_path) as img:
            arr = np.array(img.convert("RGB"))
        # create mask based on action. argparse ``choices`` guarantees one of
        # the two branches runs, so ``binary_mask`` is never stale or missing.
        if ACTION == "rgb_set_pick":
            binary_mask = mask_by_value_set(arr, RGB_COLORS)
        else:  # ACTION == "rgb_range_pick"
            binary_mask = mask_by_value_ranges(arr, RGB_COLORS, COLOR_RANGE)
        # show or export mask
        savepath = ""
        if OUT_DIR:
            filename = os.path.splitext(os.path.basename(img_path))[0]
            savepath = os.path.join(OUT_DIR, filename + OUT_APPEND)
        colorize_and_save_mask(binary_mask, TRUE_COLOR, FALSE_COLOR,
                               PLOT, savepath)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment