Skip to content

Instantly share code, notes, and snippets.

@hayden-donnelly
Created March 20, 2024 21:30
Show Gist options
  • Save hayden-donnelly/9f956921e2d71bb450e632467782209c to your computer and use it in GitHub Desktop.
Save hayden-donnelly/9f956921e2d71bb450e632467782209c to your computer and use it in GitHub Desktop.
Script to convert a directory of images into a collection of Apache Parquet files with Hugging Face `datasets` metadata.
# Example usage:
# python images_to_hf_parquet.py --input ./base_image_directory/ --output ./parquet_output_directory/ --samples_per_file 10000
import pyarrow as pa
import pyarrow.parquet as pq
from PIL import Image
import os, io, json, glob, argparse
def save_table(image_data, table_number, output_path, zfill_amount):
    """Write one shard of image rows to ``<output_path>/<padded number>.parquet``.

    Args:
        image_data: list of row dicts shaped
            ``{'image': {'bytes': <bytes>, 'path': <str>}}``, matching the
            struct column declared in the schema below.
        table_number: shard index; zero-padded in the file name.
        output_path: directory the file is written into.
        zfill_amount: minimum digit count of the padded shard index.
    """
    print(f'Entries in table {table_number}: {len(image_data)}')

    # Schema-level b'huggingface' metadata tags the `image` column with the
    # `Image` feature type; this is what lets Hugging Face `datasets`
    # presumably treat the column as images rather than a plain struct.
    hf_features = {'image': {'_type': 'Image'}}
    hf_metadata = json.dumps({'info': {'features': hf_features}}).encode('utf-8')

    image_struct = pa.struct([('bytes', pa.binary()), ('path', pa.string())])
    schema = pa.schema(
        fields=[('image', image_struct)],
        metadata={b'huggingface': hf_metadata},
    )

    file_name = str(table_number).zfill(zfill_amount) + '.parquet'
    pq.write_table(
        pa.Table.from_pylist(image_data, schema=schema),
        f'{output_path}/{file_name}',
    )
def _parse_args():
    """Parse and validate the command-line arguments for main().

    Validation uses ``parser.error`` rather than ``assert`` so the checks
    still run under ``python -O`` and report a normal usage message.
    """
    parser = argparse.ArgumentParser(
        description='Convert a directory of images to Parquet files with '
                    'Hugging Face metadata.'
    )
    parser.add_argument('--input', type=str, required=True,
                        help='Directory searched recursively for images.')
    parser.add_argument('--output', type=str, required=True,
                        help='Directory the .parquet shards are written to.')
    parser.add_argument('--samples_per_file', type=int, required=True,
                        help='Maximum number of images per .parquet file.')
    parser.add_argument('--extension', type=str, default='png',
                        help='File extension to match, without the dot.')
    parser.add_argument('--total_samples', type=int, default=-1,
                        help='Maximum number of images to convert; -1 for all.')
    parser.add_argument('--zfill', type=int, default=4,
                        help='Zero-padding width of the shard file names.')
    args = parser.parse_args()

    if args.zfill <= 0:
        parser.error('zfill must be greater than 0.')
    if args.samples_per_file <= 0:
        parser.error('samples_per_file must be greater than 0.')
    if args.total_samples < -1:
        parser.error('total_samples must be -1 (no limit) or a non-negative count.')
    if args.extension.startswith('.'):
        parser.error('Extension should not start with a dot, '
                     'for example, .png should just be png')
    return args


def _find_image_paths(input_dir, extension):
    """Return a sorted list of files under input_dir (recursively) ending in .extension."""
    # Escape the directory so characters like '[' in it are not treated as
    # glob wildcards; only the trailing '**/*.ext' part is a pattern.
    pattern = os.path.join(glob.escape(input_dir), '**', f'*.{extension}')
    # glob's order is filesystem-dependent; sorting makes shard contents reproducible.
    return sorted(glob.glob(pattern, recursive=True))


def _encode_png(path):
    """Open the image at path and return it re-encoded as PNG bytes."""
    with Image.open(path) as image:
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
    return buffer.getvalue()


def main():
    """Convert matching images under --input into sharded Parquet files in --output.

    Every image is re-encoded as PNG and written by save_table() in shards of
    at most --samples_per_file rows, named with a zero-padded shard index
    (width --zfill). --total_samples caps how many images are converted
    (-1 means no cap). Nothing is written when no images are selected.
    """
    args = _parse_args()
    os.makedirs(args.output, exist_ok=True)

    paths = _find_image_paths(args.input, args.extension)
    print(f'Found {len(paths)} files.')
    if args.total_samples != -1 and len(paths) > args.total_samples:
        # Truncate up front so exactly total_samples images are converted
        # (the previous in-loop check let one extra image through).
        print('Reached max sample count')
        paths = paths[:args.total_samples]
    if not paths:
        print('No images to convert; no parquet files written.')
        return

    image_data = []
    file_number = 0
    for i, path in enumerate(paths):
        image_data.append({
            'image': {
                'bytes': _encode_png(path),
                # The bytes are always PNG, so label the entry .png regardless
                # of the source file's extension.
                'path': f'{i}.png',
            }
        })
        if len(image_data) == args.samples_per_file:
            save_table(image_data, file_number, args.output, args.zfill)
            image_data = []
            file_number += 1

    # Flush the final partial shard; skipped when the last shard was exactly full.
    if image_data:
        save_table(image_data, file_number, args.output, args.zfill)
# Run the conversion only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment