Skip to content

Instantly share code, notes, and snippets.

@ialhashim
Created January 15, 2023 07:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ialhashim/ffd1c54645f6295e17aa99ed3cc5bef7 to your computer and use it in GitHub Desktop.
Save ialhashim/ffd1c54645f6295e17aa99ed3cc5bef7 to your computer and use it in GitHub Desktop.
Split images into square patches
def split_to_batches(img, trained_size=512):
    """Resize an image so its short side is ``trained_size`` and cut it into square patches.

    The image is converted with ``skimage.img_as_float32``. It is then resized
    (with anti-aliasing, ``preserve_range=True``) so that its shorter side equals
    ``trained_size`` while the aspect ratio is kept. Square patches of
    ``trained_size`` x ``trained_size`` pixels are taken along the longer side.
    The last patch is shifted back so it ends flush with the image border, so it
    may overlap the previous patch.

    Args:
        img: Image array of shape (H, W) or (H, W, C), in any dtype accepted by
            ``skimage.img_as_float32``.
        trained_size: Side length, in pixels, of each square patch.

    Returns:
        A tuple ``(direction, batches, img)``:

        - ``direction`` is ``'landscape'`` when H <= W, otherwise ``'portrait'``.
        - ``batches`` is a list of ``((x, y), patch)`` pairs, where ``(x, y)`` is
          the patch's top-left corner in the resized image.
        - ``img`` is the resized float image the patches were cut from.

    Raises:
        ValueError: If ``trained_size`` is not positive or the image has a
            zero-sized spatial dimension.
    """
    if trained_size <= 0:
        raise ValueError(f"trained_size must be positive, got {trained_size}")
    img = skimage.img_as_float32(img)
    height, width = img.shape[0], img.shape[1]
    if height == 0 or width == 0:
        raise ValueError(f"cannot split an empty image of shape {img.shape}")

    direction = 'portrait' if height > width else 'landscape'

    # Scale so the short side equals trained_size, preserving the aspect ratio.
    if direction == 'landscape':
        new_height = trained_size
        new_width = int(new_height * (width / height))
    else:
        new_width = trained_size
        new_height = int(new_width * (height / width))
    img = resize(img, (new_height, new_width), preserve_range=True, anti_aliasing=True)

    # Patches only step along the long side; the short side is exactly one patch.
    long_side = new_width if direction == 'landscape' else new_height
    batch_count = math.ceil(long_side / trained_size)
    # Clamp start offsets so the final patch stays inside the image.
    last_start = long_side - trained_size

    batches = []
    for i in range(batch_count):
        offset = min(i * trained_size, last_start)
        x, y = (offset, 0) if direction == 'landscape' else (0, offset)
        # Slice only the two spatial axes so 2-D (grayscale) images work too.
        batch = img[y:y + trained_size, x:x + trained_size]
        batches.append(((x, y), batch))
    return direction, batches, img
def _float_to_uint8(arr):
    """Map a float array in [0, 1] to uint8, rounding and clipping instead of truncating/wrapping."""
    return np.clip(np.rint(np.asarray(arr) * 255), 0, 255).astype(np.uint8)


def predict_in_batches(img):
    """Run the module-level ``model`` over square patches of ``img`` and return an RGB image.

    The image is resized and split by ``split_to_batches``. All patches go
    through ``model.predict`` as a single uint8 batch. The predictions are pasted
    back at their patch positions and resized to the input's height and width.

    Finally, the L* (lightness) channel of the prediction, in CIELAB space, is
    replaced with the L* channel of the original image. The prediction therefore
    only contributes the a*/b* channels.

    ``model`` is a global that must be defined elsewhere in the module.
    NOTE(review): the ``* 255`` scaling below assumes ``model.predict`` returns
    floats in [0, 1] with one (H, W, C) output per patch — confirm against the model.

    Args:
        img: RGB image of shape (H, W, 3), in any dtype accepted by
            ``skimage.img_as_float32`` (``color.rgb2lab`` requires RGB input).

    Returns:
        uint8 RGB image of shape (H, W, 3).
    """
    _, batches, resized = split_to_batches(img)

    # Model input: stack the patches and scale float [0, 1] -> uint8.
    batch = _float_to_uint8(np.stack([patch for _, patch in batches]))
    batch_hat = model.predict(batch)
    if len(batch_hat.shape) < 4:
        # NOTE(review): presumably covers a model that drops the batch axis for a
        # single patch — verify against the model's output contract.
        batch_hat = np.expand_dims(batch_hat, 0)

    # Paste predictions back at their patch positions. Later patches overwrite
    # any overlap with earlier ones. This reuses the resized-input buffer, which
    # is not needed afterwards.
    prediction_img = resized
    for i, ((x, y), patch) in enumerate(batches):
        prediction_img[y:y + patch.shape[0], x:x + patch.shape[1], :] = batch_hat[i]

    # Back to the caller's resolution. Round and clip so values slightly outside
    # [0, 1] saturate instead of wrapping around in the uint8 cast.
    prediction_img = _float_to_uint8(
        resize(prediction_img, (img.shape[0], img.shape[1]), preserve_range=True, anti_aliasing=True)
    )
    img = _float_to_uint8(skimage.img_as_float32(img))

    # Keep the prediction's a*/b* channels but take lightness (L*, channel 0) from the input.
    lab_prediction_img = color.rgb2lab(prediction_img)
    lab_img = color.rgb2lab(img)
    lab_prediction_img[:, :, 0] = lab_img[:, :, 0]
    return _float_to_uint8(color.lab2rgb(lab_prediction_img))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment