Skip to content

Instantly share code, notes, and snippets.

@daniel-j-h
Last active July 10, 2018 12:34
Show Gist options
  • Save daniel-j-h/18b255ad5c82927413ea71bd830c9c51 to your computer and use it in GitHub Desktop.
Save daniel-j-h/18b255ad5c82927413ea71bd830c9c51 to your computer and use it in GitHub Desktop.
Test-time augmentation utility helper for https://github.com/mapbox/robosat
#!/usr/bin/env python3
'''
pip install tqdm pillow mercantile
'''
'''
Simple image rotation script for test-time augmentation.
Usage:
- predict on original slippy map dir with image tiles, save probabilities in probs0 directory
- copy original slippy map dir three times, use this script to rotate by 90, 180, 270, respectively
- predict on three new slippy map dirs, save in probs1, probs2, probs3 directory
- use this script to rotate probabilities back by 270, 180, 90, respectively
- use `rs masks` on probs0 probs1 probs2 probs3; it handles weighted average soft-voting already
'''
import os
import sys
import argparse
import concurrent.futures
import mercantile
from PIL import Image
from tqdm import tqdm
def tiles_from_slippy_map(root):
    '''Yield ``(tile, path)`` for every tile file in a ``root/z/x/y.ext`` slippy map directory.

    Entries that cannot be part of the tile pyramid are skipped instead of
    aborting the walk:
      - z and x entries whose names are not decimal integers or that are not directories
        (e.g. an ``index.html`` or ``.DS_Store`` next to the zoom folders);
      - files whose name without extension is not a decimal integer.

    Args:
        root: path to the slippy map directory.

    Yields:
        ``(mercantile.Tile, str)`` pairs: the tile coordinate and the file path.
    '''
    def numeric_dirs(parent):
        # Only sub-directories named like a tile coordinate belong to the pyramid.
        for name in os.listdir(parent):
            if name.isdecimal() and os.path.isdir(os.path.join(parent, name)):
                yield name

    for z in numeric_dirs(root):
        for x in numeric_dirs(os.path.join(root, z)):
            for name in os.listdir(os.path.join(root, z, x)):
                y = os.path.splitext(name)[0]
                # Skip stray files such as thumbs.db or *.png.aux.xml side-cars.
                if not y.isdecimal():
                    continue
                tile = mercantile.Tile(x=int(x), y=int(y), z=int(z))
                path = os.path.join(root, z, x, name)
                yield tile, path
def main():
    '''Rotate every tile under a slippy map directory in place.

    Command line arguments:
        degree: rotation angle, one of 90, 180 or 270, applied with PIL's
            ``Image.ROTATE_*`` transpose methods.
        root: slippy map directory laid out as ``root/z/x/y.ext``; every tile
            file is overwritten with its rotated version.
        --threads: number of worker threads (default 1, must be at least 1).

    Tiles that fail to load, rotate or save are reported on stderr together
    with the error and skipped; the remaining tiles are still processed.
    '''
    rotations = {90: Image.ROTATE_90, 180: Image.ROTATE_180, 270: Image.ROTATE_270}

    parser = argparse.ArgumentParser()
    parser.add_argument('degree', type=int, choices=rotations.keys())
    parser.add_argument('root', type=str)
    parser.add_argument('--threads', type=int, default=1)
    args = parser.parse_args()

    # Report bad input as a usage error instead of a traceback from
    # os.listdir / ThreadPoolExecutor further down.
    if not os.path.isdir(args.root):
        parser.error('root is not a directory: {}'.format(args.root))
    if args.threads < 1:
        parser.error('--threads must be at least 1')

    method = rotations[args.degree]
    tiles = list(tiles_from_slippy_map(args.root))

    def worker(item):
        # Runs in a pool thread and returns (item, error); error is None on success.
        # It deliberately does not touch the progress bar.
        _, path = item
        try:
            # transpose() loads the pixels, so the source file can be closed
            # before the same path is overwritten.
            with Image.open(path) as image:
                rotated = image.transpose(method)
            rotated.save(path, optimize=True)
        except Exception as error:  # best-effort per tile: report, don't abort the run
            return item, error
        return item, None

    with tqdm(total=len(tiles), desc='Rotating', unit='tile', ascii=True) as progress:
        with concurrent.futures.ThreadPoolExecutor(args.threads) as executor:
            for item, error in executor.map(worker, tiles):
                # Progress is updated only from the main thread.
                progress.update()
                if error is not None:
                    # tqdm.write keeps the warning from corrupting the progress bar.
                    tqdm.write('Warning: {} failed ({}), skipping'.format(item, error), file=sys.stderr)
if __name__ == '__main__':
    # Only run the CLI when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment