Skip to content

Instantly share code, notes, and snippets.

@lucasdavid
Last active December 16, 2022 18:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lucasdavid/c48d39399b0bbe167cc2f5c056beadf0 to your computer and use it in GitHub Desktop.
Save lucasdavid/c48d39399b0bbe167cc2f5c056beadf0 to your computer and use it in GitHub Desktop.
Test which option translates segmentation-mask pixels to class labels faster
import argparse
import os
import time
import numpy as np
import tensorflow as tf
# Command-line options for the benchmark. Multi-word flags accept both the
# underscore and the hyphen spelling; argparse derives the attribute name
# (args.batch_size, args.data_dir) from the first long option listed.
parser = argparse.ArgumentParser()
parser.add_argument('--opt', default='a', choices=['a', 'b'],
                    help="decoding strategy: 'a' = dense lookup table, "
                         "'b' = argmax over the colormap (default: %(default)s)")
parser.add_argument('--batch_size', '--batch-size', type=int, default=32,
                    help='masks per sampled batch (default: %(default)s)')
parser.add_argument('--iterations', type=int, default=1000,
                    help='timed calls per decoder (default: %(default)s)')
parser.add_argument('--data-dir', '--data_dir', default='~/VOCdevkit/VOC2012',
                    help='dataset root containing SegmentationClass/ '
                         '(default: %(default)s)')
# --- Benchmark ---
def mask_decode_py(mask: tf.Tensor) -> tf.Tensor:
    """Translate RGB-encoded segmentation masks into class-label maps.

    Each pixel's (R, G, B) triple is packed into one integer
    ``R * 256**2 + G * 256 + B``. That key is looked up in the module-level
    ``COLOR_MAP``, using the strategy selected by ``args.opt``:

    * ``'a'``: ``COLOR_MAP`` is a dense table indexed by the packed color.
      The label is gathered directly.
    * ``'b'``: ``COLOR_MAP`` is a vector of packed colors. The label is the
      ``argmax`` of the per-pixel equality test against every entry.
      Note that TF documents ``argmax`` tie-breaking as unspecified, so a
      pixel that matches no entry gets an arbitrary (typically 0) label.

    ``args`` and ``COLOR_MAP`` are globals assigned in the ``__main__``
    section. Under ``tf.function`` they are read when the function is traced.

    Args:
        mask: integer tensor whose last axis holds the R, G, B channels. The
            tf.function wrappers declare it as int32 ``[batch, H, W, 3]``.

    Returns:
        An ``int32`` label tensor shaped like ``mask`` but with a last axis of
        size 1.
    """
    # Pack R, G, B into one key; the largest key, 256**3 - 1, fits in int32.
    packed = mask[..., 0] * 256**2 + mask[..., 1] * 256 + mask[..., 2]
    if args.opt == 'a':
        labels = tf.gather(COLOR_MAP, packed)
    else:
        # [..., 1] == [num_colors] broadcasts to [..., num_colors]; argmax over
        # the colors axis picks the matching COLOR_MAP row. output_type=int32
        # keeps the dtype identical to option 'a' (argmax defaults to int64).
        labels = tf.argmax(packed[..., tf.newaxis] == COLOR_MAP, axis=-1,
                           output_type=tf.int32)
    return labels[..., tf.newaxis]
# Every compiled variant accepts a batch of int32 RGB masks: [batch, H, W, 3].
py_fn_input_signature = [tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.int32)]


def _wrap_decoder(label, **extra_options):
    """Return ``mask_decode_py`` wrapped in ``tf.function`` and renamed ``label``.

    All variants share ``py_fn_input_signature`` and ``reduce_retracing=True``.
    ``extra_options`` carries the variant-specific ``tf.function`` flags.
    """
    wrapped = tf.function(
        func=mask_decode_py,
        reduce_retracing=True,
        input_signature=py_fn_input_signature,
        **extra_options,
    )
    # benchmark() prints fn.__name__, so each variant needs its own label.
    wrapped.__name__ = label
    return wrapped


# XLA-compiled graph variant.
mask_decode_jit = _wrap_decoder('mask_decode_jit', jit_compile=True)
# Graph variant with tf.function's default jit_compile setting.
mask_decode_tf = _wrap_decoder('mask_decode_tf')
def sample_from(masks, size=(512, 512)):
    """Load a random batch of segmentation masks as an int32 tensor.

    Draws ``args.batch_size`` distinct file names from ``masks``. Each name
    is resolved relative to ``MASKS_DIR``, decoded as an RGB PNG and resized
    to ``size``.

    Args:
        masks: sequence of mask file names inside ``MASKS_DIR``.
        size: target ``(height, width)`` of every mask in the batch.

    Returns:
        ``int32`` tensor of shape ``[args.batch_size, height, width, 3]``.

    Raises:
        ValueError: from ``np.random.choice`` when ``args.batch_size`` exceeds
            ``len(masks)``, because sampling is without replacement.
    """
    names = np.random.choice(masks, size=args.batch_size, replace=False)
    batch = []
    for name in names:
        # channels=3 forces an RGB result regardless of how the PNG is stored.
        image = tf.io.decode_png(
            tf.io.read_file(os.path.join(MASKS_DIR, name)), channels=3)
        # Nearest-neighbour resizing keeps every pixel an exact palette
        # color. The default bilinear filter blends neighbouring classes into
        # colors that match no COLOR_MAP entry.
        batch.append(tf.image.resize(image, size, method='nearest'))
    return tf.cast(tf.stack(batch), tf.int32)
def benchmark(fn, masks, warmup=3):
    """Time ``fn`` on freshly sampled mask batches and print mean latencies.

    Each batch is loaded with ``sample_from`` outside the timed region, so
    only the call to ``fn`` is measured. The first ``warmup`` calls are
    reported separately from the ``args.iterations`` measured calls, which
    keeps one-off costs such as tf.function tracing out of the main average.

    Args:
        fn: callable taking one batch tensor; its ``__name__`` labels the output.
        masks: mask file names forwarded to ``sample_from``.
        warmup: number of warm-up calls averaged separately (default 3).

    Returns:
        The result of the last call to ``fn``, or ``None`` if it was never called.
    """
    def _time_once():
        # Sample first, then time only the decoder call.
        batch = sample_from(masks)
        start = time.perf_counter()
        out = fn(batch)
        return out, time.perf_counter() - start

    # NOTE(review): on GPU, TF may hand back the result before the kernels
    # finish, in which case these timings undercount device time. Consider
    # forcing a sync (e.g. reading the result) if that matters here.
    result = None
    warmup_time = 0.0
    for _ in range(warmup):
        result, elapsed = _time_once()
        warmup_time += elapsed

    total_time = 0.0
    for _ in range(args.iterations):
        result, elapsed = _time_once()
        total_time += elapsed

    # Per-call means; zero counts report 0 instead of dividing by zero.
    warmup_avg = warmup_time / warmup if warmup > 0 else 0.0
    iter_avg = total_time / args.iterations if args.iterations > 0 else 0.0
    print(f"{fn.__name__} | "
          f"warmup: {warmup_avg:7.4f} s | "
          f"avg: {iter_avg:7.4f} s")
    return result
if __name__ == "__main__":
    args = parser.parse_args()
    MASKS_DIR = os.path.expanduser(os.path.join(args.data_dir, 'SegmentationClass'))
    # Only PNG files can be decoded as masks; skip stray directory entries.
    ALL_MASKS = [m for m in os.listdir(MASKS_DIR) if m.lower().endswith('.png')]

    VOC_COLORMAP = np.array([
        [0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
        [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0],
        [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128], [64, 128, 128],
        [192, 128, 128], [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128], [224, 224, 192]
    ], dtype='int32')
    # Pack each palette color as R * 256**2 + G * 256 + B. This is the same
    # key mask_decode_py computes for every pixel.
    voc_codes = (VOC_COLORMAP[:, 0] * 256 + VOC_COLORMAP[:, 1]) * 256 + VOC_COLORMAP[:, 2]

    if args.opt == 'a':
        # Dense table over every possible packed color; unknown colors map
        # to 0. A NumPy array avoids building a 256**3-element Python list.
        COLOR_MAP = np.zeros(256**3, dtype=np.int32)
        COLOR_MAP[voc_codes] = np.arange(len(voc_codes), dtype=np.int32)
        # There is a special mapping with [224, 224, 192] -> 255.
        COLOR_MAP[voc_codes[-1]] = 255
    else:
        # Option b compares pixels against the packed colors themselves, so
        # every entry must remain a real color code. Writing 255 here would
        # turn the [224, 224, 192] entry into the color [0, 0, 255].
        COLOR_MAP = voc_codes
    COLOR_MAP = tf.constant(COLOR_MAP, dtype=tf.int32)

    print(f'Benchmarking opt {args.opt} iterations:{args.iterations}')
    benchmark(mask_decode_py, ALL_MASKS)
    benchmark(mask_decode_tf, ALL_MASKS)
    benchmark(mask_decode_jit, ALL_MASKS)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment