@FavioVazquez
Created May 2, 2018 14:57
from pyspark.ml.image import ImageSchema
from pyspark.sql.functions import lit
from sparkdl.image import imageIO
tulips_df = ImageSchema.readImages("flower_photos/tulips").withColumn("label", lit(1))
daisy_df = imageIO.readImagesWithCustomFn("flower_photos/daisy", decode_f=imageIO.PIL_decode).withColumn("label", lit(0))
tulips_train, tulips_test, _ = tulips_df.randomSplit([0.1, 0.05, 0.85])  # use a larger split (e.g. [0.6, 0.4]) to train and test on more of the images
daisy_train, daisy_test, _ = daisy_df.randomSplit([0.1, 0.05, 0.85])  # use a larger split (e.g. [0.6, 0.4]) to train and test on more of the images
train_df = tulips_train.unionAll(daisy_train)
test_df = tulips_test.unionAll(daisy_test)
# Under the hood, each partition is fully loaded in memory, which may be expensive.
# Repartitioning ensures that each partition stays small.
train_df = train_df.repartition(100)
test_df = test_df.repartition(100)
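
# A minimal sketch of how these DataFrames are typically consumed downstream,
# assuming the sparkdl Deep Learning Pipelines library is available: featurize each
# image with a pre-trained InceptionV3 network and fit a logistic regression on top
# (transfer learning). The hyperparameters below are illustrative, not tuned.
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from sparkdl import DeepImageFeaturizer

featurizer = DeepImageFeaturizer(inputCol="image", outputCol="features", modelName="InceptionV3")
lr = LogisticRegression(maxIter=20, regParam=0.05, elasticNetParam=0.3, labelCol="label")
p = Pipeline(stages=[featurizer, lr])
p_model = p.fit(train_df)

# Evaluate on the held-out test set.
predictions = p_model.transform(test_df)
evaluator = MulticlassClassificationEvaluator(metricName="accuracy")
print("Test set accuracy = " + str(evaluator.evaluate(predictions.select("prediction", "label"))))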