Skip to content

Instantly share code, notes, and snippets.

@mayankgrwl97
Created September 20, 2020 08:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mayankgrwl97/34e6ef1091881cb2045bf9aa7dbbf382 to your computer and use it in GitHub Desktop.
Save mayankgrwl97/34e6ef1091881cb2045bf9aa7dbbf382 to your computer and use it in GitHub Desktop.
import argparse
import glob
import multiprocessing
import os
from functools import partial
import cv2
from tqdm import tqdm
def png_to_jpg(img_png_path, jpg_dir):
    """Convert one image file to JPEG and store it in *jpg_dir*.

    The output keeps the input's base name with a ``.jpg`` extension.

    Args:
        img_png_path: Path of the source image (expected to be a PNG).
        jpg_dir: Existing directory that receives the JPEG.

    Returns:
        The path of the written JPEG file.

    Raises:
        IOError: if OpenCV cannot decode the source image or cannot write the
            JPEG.
    """
    # cv2.imread's default flag loads a 3-channel BGR image, so any PNG alpha
    # channel is dropped (JPEG cannot store it anyway).
    img = cv2.imread(img_png_path)
    if img is None:
        # imread reports failure by returning None instead of raising; without
        # this check imwrite fails later with an opaque OpenCV assertion.
        raise IOError(f'Could not read image: {img_png_path}')
    img_basename = os.path.splitext(os.path.basename(img_png_path))[0]
    img_jpg_path = os.path.join(jpg_dir, img_basename + '.jpg')
    # imwrite returns False (rather than raising) when it cannot write the file.
    if not cv2.imwrite(img_jpg_path, img):
        raise IOError(f'Could not write image: {img_jpg_path}')
    return img_jpg_path
def process(img_png_paths, jpg_dir, n_workers):
    """Convert every path in *img_png_paths* to JPEG using a process pool.

    Each path is handed to ``png_to_jpg`` with ``jpg_dir`` bound. Results are
    consumed in completion order (``imap_unordered``) and discarded; an
    exception raised in a worker is re-raised here when its result is reached.

    Args:
        img_png_paths: Iterable of source image paths; a sized sequence gives
            tqdm a known total.
        jpg_dir: Output directory for the JPEG files.
        n_workers: Number of worker processes in the pool.
    """
    png_to_jpg_fn = partial(png_to_jpg, jpg_dir=jpg_dir)
    try:
        total = len(img_png_paths)
    except TypeError:
        # e.g. a generator: tqdm then shows a plain counter instead of a bar.
        total = None
    with multiprocessing.Pool(n_workers) as pool:
        # Iterating the lazy imap_unordered result drives the work; tqdm wraps
        # it so progress advances as each worker finishes an image.
        for _ in tqdm(pool.imap_unordered(png_to_jpg_fn, img_png_paths), total=total):
            pass
def get_args():
    """Parse and validate the command-line arguments.

    Creates ``--jpg_dir`` (including parents) if it does not exist yet.

    Returns:
        argparse.Namespace with ``png_dir``, ``jpg_dir`` and ``n_workers``.

    Exits:
        Through ``parser.error`` (SystemExit with status 2) when ``--png_dir``
        is not an existing directory or ``--n_workers`` is below 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--png_dir', type=str, required=True)
    parser.add_argument('--jpg_dir', type=str, required=True)
    parser.add_argument('--n_workers', type=int, default=1)
    args = parser.parse_args()
    # Validate with parser.error instead of assert: asserts are stripped under
    # `python -O`, and parser.error prints a proper usage message.
    if not os.path.isdir(args.png_dir):
        parser.error(f'--png_dir is not an existing directory: {args.png_dir}')
    if args.n_workers < 1:
        # multiprocessing.Pool would otherwise fail later with a ValueError.
        parser.error(f'--n_workers must be at least 1, got {args.n_workers}')
    os.makedirs(args.jpg_dir, exist_ok=True)
    return args
if __name__ == '__main__':
    # Script entry: read the CLI flags, collect the PNGs directly inside
    # png_dir (non-recursive, sorted for a stable order) and convert them.
    args = get_args()
    png_glob = os.path.join(args.png_dir, '*.png')
    img_png_paths = sorted(glob.glob(png_glob))
    process(img_png_paths, args.jpg_dir, args.n_workers)
@nikhilweee
Copy link

You can drop the `process` function almost entirely by using `process_map` from `tqdm.contrib.concurrent`:

from itertools import repeat
from tqdm.contrib.concurrent import process_map

# Every PNG path is paired with the same output directory. Like the built-in
# map(), the executor's map stops at its shortest iterable, so the endless
# repeat() is safe here.
process_map(
    png_to_jpg,
    img_png_paths,
    repeat(jpg_dir),
    max_workers=8,
)

As a bonus, it also manages the progress bar for you.

@nikhilweee
Copy link

Here's my version of a similar script (it converts TIFF files to JPEG or PNG) using process_map, with a few extras: resuming an interrupted run, a sequential debug mode, and per-file timing.

import os
import sys
import time
import argparse
from PIL import Image
from tqdm import tqdm
from itertools import repeat
from tqdm.contrib.concurrent import process_map

def convert(src_path, dst_path, index=None, debug=False, ext='png'):
    """Re-encode the image at *src_path* and save it to *dst_path*.

    Failures (unreadable source or failed save) are reported with
    ``tqdm.write`` and the function simply returns, so one bad file does not
    abort a whole ``process_map`` run.

    Args:
        src_path: Path of the image to read.
        dst_path: Path the converted image is written to.
        index: Optional position of the file in the batch; used only in log
            lines.
        debug: When true, log per-file load and save timings.
        ext: Format name passed to ``Image.save`` (e.g. ``'png'``, ``'jpeg'``).

    Returns:
        None.
    """
    # perf_counter is monotonic, so durations are not skewed by clock changes.
    start = time.perf_counter()
    try:
        # The context manager closes the source file handle. Image.open is
        # lazy: it identifies the file, and pixel data is decoded during save().
        with Image.open(src_path) as im:
            diff = time.perf_counter() - start
            if debug:
                tqdm.write(f'Loaded \t {index}: \t {diff:04.2f}s \t{src_path}')
            start = time.perf_counter()
            im.save(dst_path, ext)
    except Exception as e:
        tqdm.write(f'Failed \t {index}: \t {src_path} \t {str(e)}')
        return
    diff = time.perf_counter() - start
    if debug:
        tqdm.write(f'Saved \t {index}: \t {diff:04.2f}s \t{dst_path}')


def preprocess(src_dir, dst_dir, debug=None, fmt=None):
    """Collect the TIFF files under *src_dir* that still need converting.

    Walks *src_dir* recursively. For every ``.tif``/``.tiff`` file (any case)
    it builds the mirrored destination path under *dst_dir*, with the
    extension replaced by ``.<fmt>``. A file is skipped when its destination
    already exists and passes ``Image.verify()``, which makes re-runs
    resumable. The destination directory of every kept file is created up
    front.

    Args:
        src_dir: Root directory to scan.
        dst_dir: Root directory under which the source tree is mirrored.
        debug: Print verbose progress. ``None`` falls back to the script's
            module-level ``args.debug``.
        fmt: Output extension/format such as ``'jpeg'``. ``None`` falls back
            to ``args.format``.

    Returns:
        ``(src_paths, dst_paths)``: two equal-length lists of file paths.
    """
    # Fall back to the CLI namespace so existing two-argument callers still work.
    if debug is None:
        debug = args.debug
    if fmt is None:
        fmt = args.format

    src_paths = []
    dst_paths = []

    for src_root, _dirs, files in os.walk(src_dir, topdown=False):
        if debug:
            print(f'Scanning: \t {src_root}')
        # Mirror by the path *relative* to src_dir. str.replace would also
        # rewrite later occurrences of src_dir inside the path, and it breaks
        # when only one of src_dir/dst_dir has a trailing separator.
        rel_root = os.path.relpath(src_root, src_dir)
        dst_root = dst_dir if rel_root == os.curdir else os.path.join(dst_dir, rel_root)
        for idx, src_file in enumerate(files):
            if not debug:
                print(f'Scanning: \t {idx + 1}/{len(files)} \t {src_root}', end='\x1b[1K\r')
            name, ext = os.path.splitext(src_file)
            if ext.lower() not in ('.tif', '.tiff'):
                continue
            src_path = os.path.join(src_root, src_file)
            dst_path = os.path.join(dst_root, name + '.' + fmt)
            if os.path.isfile(dst_path):
                if debug:
                    print(f'Verifying: \t {dst_path}')
                try:
                    # Close the verified file instead of leaking its handle.
                    with Image.open(dst_path) as img:
                        img.verify()
                except Exception:
                    # Unreadable or corrupt output: fall through and redo it.
                    pass
                else:
                    continue
            os.makedirs(dst_root, exist_ok=True)
            src_paths.append(src_path)
            dst_paths.append(dst_path)
    if debug:
        print(f'Found {len(src_paths)} image files. Processing . . .')
    return src_paths, dst_paths

if __name__ == '__main__':
    cli = argparse.ArgumentParser()
    cli.add_argument('--src-dir', required=True, help='source directory to look for images')
    cli.add_argument('--dst-dir', required=True, help='destination directory to store images')
    cli.add_argument('--format', default='jpeg', choices=['jpeg', 'png'], help='image format to save')
    cli.add_argument('--debug', action='store_true', help='run the script in debug mode')
    # Must stay bound to the module-level name `args`: preprocess() reads it.
    args = cli.parse_args()

    src_paths, dst_paths = preprocess(args.src_dir, args.dst_dir)
    indices = list(range(1, len(src_paths) + 1))
    if not args.debug:
        # Parallel run: one file per task (chunksize=1) across worker processes.
        process_map(convert, src_paths, dst_paths, indices, repeat(args.debug),
                    repeat(args.format), desc='Processing: ', chunksize=1)
    else:
        # Sequential run so debug log lines appear in file order.
        for src, dst, idx in zip(src_paths, dst_paths, indices):
            convert(src, dst, idx, args.debug, args.format)

@mayankgrwl97
Copy link
Author

To use multiprocessing with a PyTorch model, replace

from multiprocessing import Process, Pool

with

from torch.multiprocessing import Pool, Process, set_start_method
try:
    # 'spawn' starts each worker in a fresh interpreter instead of forking the
    # parent. set_start_method raises RuntimeError if the start method has
    # already been fixed for this process; in that case keep the existing one.
    set_start_method('spawn')
except RuntimeError:
    pass

Reference: https://stackoverflow.com/questions/48822463/how-to-use-pytorch-multiprocessing

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment