Skip to content

Instantly share code, notes, and snippets.

@shcheklein
Created December 19, 2022 20:34
Show Gist options
  • Save shcheklein/d28e23b391041535b781b067792e1528 to your computer and use it in GitHub Desktop.
Save shcheklein/d28e23b391041535b781b067792e1528 to your computer and use it in GitHub Desktop.
Download LAION images with EMR
from img2dataset import download
import shutil
import os
from pyspark.sql import SparkSession # pylint: disable=import-outside-toplevel
from pyspark import SparkConf, SparkContext
# Preparations:
#
# - build https://github.com/rom1504/img2dataset on the EMR machine with `make build-pex`
# - put `img2dataset.pex` (with pscp, pssh and IPs from AWS console) into `/home/hadoop`
# - put this file into `/home/hadoop` on all machines
# - make sure that `output_dir` and `url_list` are set correctly and EMR has access
#   to those locations
#
# Usage:
#
# spark-submit --conf "spark.pyspark.python=./img2dataset.pex" --files ./img2dataset.pex --master yarn download.py
def create_spark_session():
    """Create (or get) a YARN-backed SparkSession configured for img2dataset.

    Points PYSPARK_PYTHON at the img2dataset PEX so every executor uses it as
    its Python interpreter; the PEX must exist at the same path on all worker
    nodes.

    Returns:
        pyspark.sql.SparkSession: the active session (created if needed).
    """
    # this must be a path that is available on all worker nodes
    pex_file = "/home/hadoop/img2dataset.pex"
    os.environ["PYSPARK_PYTHON"] = pex_file
    # this config is sized for an EMR cluster of 10 c6i.4xlarge nodes
    spark = (
        SparkSession.builder
        .config("spark.submit.deployMode", "client")
        .config("spark.executorEnv.PEX_ROOT", "./.pex")
        .config("spark.executor.cores", "16")
        .config("spark.cores.max", "200")
        .config("spark.executor.memory", "16GB")
        .config("spark.executor.memoryOverhead", "4GB")
        .config("spark.task.maxFailures", "100")
        .master("yarn")
        .appName("LAION")
        .getOrCreate()
    )
    return spark
# S3 locations: where the downloaded images go, and where the input
# parquet metadata (laion2B-en) lives.
output_dir = "s3://dvc-private/laion/output"
url_list = "s3://dvc-private/laion/metadata/laion2B-en/"

spark = create_spark_session()

download(
    # not used with the pyspark distributor — one task is started per core
    # (nb executors * nb cores per executor)
    processes_count=1,
    thread_count=32,
    retries=0,
    url_list=url_list,
    image_size=256,
    resize_only_if_bigger=True,
    resize_mode="keep_ratio",
    skip_reencode=True,
    output_folder=output_dir,
    output_format="files",
    input_format="parquet",
    url_col="URL",
    caption_col="TEXT",
    distributor="pyspark",
    number_sample_per_shard=10000,
    oom_shard_count=6,
    # resume a previously interrupted run instead of starting over
    incremental_mode="incremental",
    save_additional_columns=["NSFW", "similarity", "LICENSE"],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment