@shawwn
Created January 29, 2020 09:11
Extract a random crop from an image. Works with TPUs and on cloud bucket paths like gs://gpt-2-poetry/test/983951.png
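Usage sketch (assuming the script is saved locally as random_crop.py; the file names and resize value below are illustrative): `python random_crop.py gs://gpt-2-poetry/test/983951.png crop.jpg -r 256` reads the image, takes a random square crop, resizes it to 256x256, and writes it to crop.jpg.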
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import datetime
import os
import shutil
import subprocess
import sys
import tensorflow.compat.v1 as tf

class Namespace():
  pass

me = Namespace()

def _int64_feature(value):
  """Wrapper for inserting int64 features into Example proto."""
  if not isinstance(value, list):
    value = [value]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def _bytes_feature(value):
  """Wrapper for inserting bytes features into Example proto."""
  if isinstance(value, str):
    value = value.encode('utf-8')  # BytesList needs bytes; callers pass str in Python 3.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _convert_to_example(filename, image_buffer, label_int, label_str, height,
                        width):
  """Build an Example proto for an example.
  Args:
    filename: string, path to an image file, e.g., '/path/to/example.JPG'
    image_buffer: string, JPEG encoding of RGB image
    label_int: integer, identifier for ground truth (0-based)
    label_str: string, identifier for ground truth, e.g., 'daisy'
    height: integer, image height in pixels
    width: integer, image width in pixels
  Returns:
    Example proto
  """
  colorspace = 'RGB'
  channels = 3
  image_format = 'JPEG'
  example = tf.train.Example(
      features=tf.train.Features(
          feature={
              'image/height': _int64_feature(height),
              'image/width': _int64_feature(width),
              'image/colorspace': _bytes_feature(colorspace),
              'image/channels': _int64_feature(channels),
              'image/class/label': _int64_feature(label_int + 1),  # model expects 1-based
              'image/class/synset': _bytes_feature(label_str),
              'image/format': _bytes_feature(image_format),
              'image/filename': _bytes_feature(os.path.basename(filename)),
              'image/encoded': _bytes_feature(image_buffer)
          }))
  return example
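
# Example usage sketch (hypothetical paths and labels; not called by this script):
# serialize one image into a TFRecord using the Example proto built above.
#   example = _convert_to_example('/path/to/daisy.jpg', image_buffer, 0, 'daisy', height, width)
#   with tf.python_io.TFRecordWriter('out.tfrecords') as writer:
#     writer.write(example.SerializeToString())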

class ImageCoder(object):
  """Helper class that provides TensorFlow image coding utilities."""

  def __init__(self, session=None):
    # Create a single Session to run all image coding calls.
    session = tf.get_default_session() if session is None else session
    self._sess = tf.Session() if session is None else session
    # Initializes function that converts PNG to JPEG data.
    self._png_data = tf.placeholder(dtype=tf.string)
    image = tf.image.decode_png(self._png_data, channels=3)
    self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100)
    # Initializes function that converts CMYK JPEG data to RGB JPEG data.
    self._cmyk_data = tf.placeholder(dtype=tf.string)
    image = tf.image.decode_jpeg(self._cmyk_data, channels=0)
    self._cmyk_to_rgb = tf.image.encode_jpeg(image, format='rgb', quality=100)
    # Initializes function that decodes RGB JPEG data.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

  def __del__(self):
    self._sess.close()

  def png_to_jpeg(self, image_data):
    return self._sess.run(self._png_to_jpeg,
                          feed_dict={self._png_data: image_data})

  def cmyk_to_rgb(self, image_data):
    return self._sess.run(self._cmyk_to_rgb,
                          feed_dict={self._cmyk_data: image_data})

  def decode_jpeg(self, image_data):
    image = self._sess.run(self._decode_jpeg,
                           feed_dict={self._decode_jpeg_data: image_data})
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image
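
# Example usage sketch (hypothetical path): decode JPEG bytes, including from a
# gs:// bucket path, into an RGB numpy array of shape [height, width, 3].
#   coder = ImageCoder()
#   with tf.gfile.FastGFile('gs://some-bucket/image.jpg', 'rb') as f:
#     rgb = coder.decode_jpeg(f.read())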

# Parse individual image from a tfrecords file into TensorFlow expression.
def parse_tfrecord_tf(record):
  features = tf.parse_single_example(record, features={
      'shape': tf.FixedLenFeature([3], tf.int64),
      'data': tf.FixedLenFeature([], tf.string)})
  data = tf.decode_raw(features['data'], tf.uint8)
  return tf.reshape(data, features['shape'])

def parse_tfrecord_file(tfr_file, num_threads=8, buffer_mb=256):
  # buffer_mb: read-buffer size in megabytes.
  dset = tf.data.TFRecordDataset(tfr_file, compression_type='', buffer_size=buffer_mb << 20)
  dset = dset.map(parse_tfrecord_tf, num_parallel_calls=num_threads)
  return dset

def init_dataset(dset):
  tf_iterator = tf.data.Iterator.from_structure(dset.output_types, dset.output_shapes)
  tf_init_ops = tf_iterator.make_initializer(dset)
  return tf_iterator, tf_init_ops

from tensorflow.python.framework import errors_impl

def iterate_stylegan_records(tfr_file, session=None):
  if session is None:
    session = tf.get_default_session()
  dset = parse_tfrecord_file(tfr_file)
  it, init = init_dataset(dset)
  session.run(init)
  with session.graph.as_default():
    op = tf.transpose(it.get_next(), [1, 2, 0])  # CHW -> HWC
  try:
    while True:
      yield session.run(op)
  except errors_impl.OutOfRangeError:
    pass
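
# Example usage sketch (hypothetical filename): iterate decoded HWC uint8 images
# from a StyleGAN-style .tfrecords file.
#   with tf.Session() as sess:
#     for img in iterate_stylegan_records('dataset-r08.tfrecords', session=sess):
#       print(img.shape)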

def _is_png(filename):
  return filename.endswith('.png') or filename.endswith('.PNG')

def _is_png_data(image_data):
  return image_data.startswith(b'\x89PNG')

def _is_cmyk(filename):
  return False  # TODO

def get_coder(coder=None):
  # Reuse a coder passed by the caller; otherwise lazily create and cache one on `me`.
  if coder is not None:
    return coder
  coder = me.coder = me.coder if hasattr(me, 'coder') else ImageCoder()
  return coder

def _process_image(filename, coder=None):
  """Process a single image file.
  Args:
    filename: string, path to an image file e.g., '/path/to/example.JPG'.
    coder: instance of ImageCoder to provide TensorFlow image coding utils.
  Returns:
    image_buffer: string, JPEG encoding of RGB image.
    height: integer, image height in pixels.
    width: integer, image width in pixels.
    image: decoded RGB image as a numpy array of shape [height, width, 3].
  """
  coder = get_coder(coder)
  # Read the image file (tf.gfile also handles gs:// bucket paths).
  with tf.gfile.FastGFile(filename, 'rb') as f:
    image_data = f.read()
  # Clean the dirty data.
  if _is_png_data(image_data):
    # The image is a PNG.
    tf.logging.info('Converting PNG to JPEG for %s' % filename)
    image_data = coder.png_to_jpeg(image_data)
  elif _is_cmyk(filename):
    # The JPEG image is in CMYK colorspace.
    tf.logging.info('Converting CMYK to RGB for %s' % filename)
    image_data = coder.cmyk_to_rgb(image_data)
  # Decode the RGB JPEG.
  image = coder.decode_jpeg(image_data)
  # Check that the image converted to RGB.
  assert len(image.shape) == 3
  height = image.shape[0]
  width = image.shape[1]
  assert image.shape[2] == 3
  return image_data, height, width, image

def tf_randi(*args, **kws):
  """Random scalar int32 in [lo, hi): tf_randi(hi) or tf_randi(lo, hi)."""
  assert len(args) > 0
  if len(args) == 1:
    lo, hi = [0] + [x for x in args]
  else:
    lo, hi = args
  return tf.random.uniform((), minval=lo, maxval=hi, dtype=tf.int32, **kws)

def tf_rand(*args, **kws):
  """Random scalar float in [lo, hi): tf_rand(), tf_rand(hi), or tf_rand(lo, hi)."""
  if len(args) == 0:
    lo, hi = 0.0, 1.0
  elif len(args) == 1:
    lo, hi = [0] + [x for x in args]
  else:
    lo, hi = args
  return tf.random.uniform((), minval=lo, maxval=hi, **kws)

def tf_biased_rand(*args, bias=3, **kws):
  x = tf_rand(*args, **kws)
  # Simple technique to bias the result towards the center:
  # average `bias` independent uniform samples.
  for i in range(bias - 1):
    x += tf_rand(*args, **kws)
  dtype = kws.pop('dtype') if 'dtype' in kws else tf.float32
  x = tf.cast(x, tf.float32) / bias
  x = tf.cast(x, dtype)
  return x

def tf_between(*args, bias=3, **kws):
  if bias <= 0:
    return tf_randi(*args, **kws)
  else:
    return tf_biased_rand(*args, dtype=tf.int32, bias=bias, **kws)
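
# Note: averaging several uniform samples concentrates the offset toward the middle
# of the valid range (a Bates-like distribution), so crops tend to come from near the
# center of the image rather than hugging its edges. Rough sanity check (hypothetical,
# NumPy, not used by this script):
#   import numpy as np
#   offsets = np.random.uniform(0, 100, size=(10000, 3)).mean(axis=1)
#   # offsets cluster around 50 instead of being spread uniformly over [0, 100).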

def random_crop(image_bytes, scope=None, resize=None, method=tf.image.ResizeMethod.AREA, seed=None):
  """Decode a JPEG and take a random square crop (the largest square that fits)."""
  with tf.name_scope(scope, 'random_crop', [image_bytes]):
    # extract_jpeg_shape returns [height, width, channels].
    shape = tf.image.extract_jpeg_shape(image_bytes)
    h = shape[0]
    w = shape[1]
    channels = shape[2]
    n = 3
    # decode_and_crop_jpeg takes a crop window of [y, x, crop_height, crop_width].
    # Crop a square with side min(h, w) at a random (center-biased) offset along
    # the longer dimension.
    image = tf.cond(h > w,
        lambda: tf.image.decode_and_crop_jpeg(
            image_bytes, tf.stack([tf_between(h - w, seed=seed), 0, w, w]), channels=n),
        lambda: tf.cond(w > h,
            lambda: tf.image.decode_and_crop_jpeg(
                image_bytes, tf.stack([0, tf_between(w - h, seed=seed), h, h]), channels=n),
            lambda: tf.image.decode_jpeg(image_bytes, channels=n)))
    if resize:
      image_size = [resize, resize] if isinstance(resize, (int, float)) else resize
      image = tf.image.resize([image], image_size, method=method)[0]
      # tf.image.resize returns floats; convert back to uint8 so the result can be
      # JPEG-encoded directly.
      image = tf.saturate_cast(tf.round(image), tf.uint8)
    return image

# Quick manual test of random_crop (run inside a default session, with numpy and glob imported):
#   with open('test.jpg', 'wb') as f:
#     f.write(tf.get_default_session().run(tf.io.encode_jpeg(
#         sess.run(random_crop(open(np.random.choice(list(glob('*.jpg'))), 'rb').read())))))

IMAGE_SIZE = 224
CROP_PADDING = 32

def distorted_bounding_box_crop(image_bytes,
                                bbox,
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=100,
                                scope=None):
  """Generates cropped_image using one of the bboxes randomly distorted.
  See `tf.image.sample_distorted_bounding_box` for more documentation.
  Args:
    image_bytes: `Tensor` of binary image data.
    bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
      where each coordinate is [0, 1) and the coordinates are arranged
      as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
      image.
    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
      area of the image must contain at least this fraction of any bounding
      box supplied.
    aspect_ratio_range: An optional list of `float`s. The cropped area of the
      image must have an aspect ratio = width / height within this range.
    area_range: An optional list of `float`s. The cropped area of the image
      must contain a fraction of the supplied image within this range.
    max_attempts: An optional `int`. Number of attempts at generating a cropped
      region of the image of the specified constraints. After `max_attempts`
      failures, return the entire image.
    scope: Optional `str` for name scope.
  Returns:
    cropped image `Tensor`
  """
  with tf.name_scope(scope, 'distorted_bounding_box_crop', [image_bytes, bbox]):
    shape = tf.image.extract_jpeg_shape(image_bytes)
    sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
        shape,
        bounding_boxes=bbox,
        min_object_covered=min_object_covered,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
        max_attempts=max_attempts,
        use_image_if_no_bounding_boxes=True)
    bbox_begin, bbox_size, _ = sample_distorted_bounding_box
    # Crop the image to the specified bounding box.
    offset_y, offset_x, _ = tf.unstack(bbox_begin)
    target_height, target_width, _ = tf.unstack(bbox_size)
    crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
    image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
    return image

def _at_least_x_are_equal(a, b, x):
  """At least `x` of `a` and `b` `Tensors` are equal."""
  match = tf.equal(a, b)
  match = tf.cast(match, tf.int32)
  return tf.greater_equal(tf.reduce_sum(match), x)

def _decode_and_random_crop(image_bytes, image_size, resize=True,
                            method=tf.image.ResizeMethod.AREA,  # AREA is much higher quality than BICUBIC
                            aspect_ratio_range=(3. / 4, 4. / 3),
                            area_range=(0.08, 1.0)):
  """Make a random crop of image_size."""
  bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
  image = distorted_bounding_box_crop(
      image_bytes,
      bbox,
      min_object_covered=0.1,
      aspect_ratio_range=aspect_ratio_range,
      area_range=area_range,
      max_attempts=10,
      scope=None)
  original_shape = tf.image.extract_jpeg_shape(image_bytes)
  # If the sampled crop degenerated to the whole image, fall back to a center crop.
  bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)
  if resize:
    image = tf.cond(
        bad,
        lambda: _decode_and_center_crop(image_bytes, image_size, resize=resize, method=method),
        lambda: tf.image.resize([image], [image_size, image_size], method=method)[0])
  else:
    image = tf.cond(
        bad,
        lambda: _decode_and_center_crop(image_bytes, image_size, resize=resize, method=method),
        lambda: image)
  return image

def _decode_and_center_crop(image_bytes, image_size, resize=True,
                            method=tf.image.ResizeMethod.AREA):  # AREA is much higher quality than BICUBIC
  """Crops to center of image with padding then scales image_size."""
  shape = tf.image.extract_jpeg_shape(image_bytes)
  image_height = shape[0]
  image_width = shape[1]
  padded_center_crop_size = tf.cast(
      ((image_size / (image_size + CROP_PADDING)) *
       tf.cast(tf.minimum(image_height, image_width), tf.float32)),
      tf.int32)
  offset_height = ((image_height - padded_center_crop_size) + 1) // 2
  offset_width = ((image_width - padded_center_crop_size) + 1) // 2
  crop_window = tf.stack([offset_height, offset_width,
                          padded_center_crop_size, padded_center_crop_size])
  image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
  if resize:
    image = tf.image.resize([image], [image_size, image_size], method=method)[0]
  return image
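
# Worked example (hypothetical 500x375 input, image_size=224, CROP_PADDING=32):
# the crop side is int(224 / (224 + 32) * 375) = int(0.875 * 375) = 328, so a
# centered 328x328 window is cut out and then resized to 224x224.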

import os
from tensorflow.compat.v1.distribute.cluster_resolver import TPUClusterResolver

def get_target(target=None):
  """Resolve a TPU/session target from the environment, or return None for a local session."""
  if target is None and 'COLAB_TPU_ADDR' in os.environ:
    target = os.environ['COLAB_TPU_ADDR']
  if target is None and 'TPU_NAME' in os.environ:
    target = os.environ['TPU_NAME']
  if target is not None and not target.startswith('grpc://'):
    target = TPUClusterResolver(target).get_master()
  if target is not None:
    print('Using target %s' % target)
  return target

def main():
  from tensorflow.python.framework.ops import disable_eager_execution
  disable_eager_execution()
  parser = argparse.ArgumentParser()
  parser.add_argument('infile')
  parser.add_argument('outfile')
  parser.add_argument('-r', '--resize', type=int, default=0)
  args = me.args = parser.parse_args()
  with tf.Session(get_target()) as sess:
    image_data, height, width, image = _process_image(args.infile)
    image_out = sess.run(tf.io.encode_jpeg(sess.run(random_crop(image_data, resize=args.resize))))
    with open(args.outfile, 'wb') as f:
      f.write(image_out)

if __name__ == '__main__':
  main()