# Gist by @mrafayaleem, created March 7, 2019 18:57
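import os

from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType, StringType
from zoo.pipeline.nnframes import NNImageReader

# `sc` (SparkContext), `train_path` and `val_path` are assumed to be defined earlier.
# UDFs over the image struct column, whose first field (`origin`) is the image's file path:
# extract the file name, and derive a float label (1.0 if the path contains 'ants', else 2.0).
getFileName = udf(lambda row: os.path.basename(row[0]), StringType())
getLabel = udf(lambda row: 1.0 if 'ants' in row[0] else 2.0, DoubleType())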
# Construct the training DataFrame: images resized to 300x300 and decoded with image_codec=1
# (OpenCV's IMREAD_COLOR flag, i.e. 3-channel color), plus 'filename' and 'label' columns
trainingDF = NNImageReader.readImages(train_path, sc, resizeH=300, resizeW=300, image_codec=1)
trainingDF = trainingDF.withColumn('filename', getFileName('image')).withColumn('label', getLabel('image'))
# Construct the validation DataFrame the same way
validationDF = NNImageReader.readImages(val_path, sc, resizeH=300, resizeW=300, image_codec=1)
validationDF = validationDF.withColumn('filename', getFileName('image')).withColumn('label', getLabel('image'))
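
# A Spark-free sketch of the same path logic used by the UDFs, assuming the image
# `origin` field holds a path or URI such as "file:/data/train/ants/img_001.jpg"
# (the example paths and the helper names below are hypothetical). Plain functions
# like these can be unit-tested without a SparkContext and then wrapped with `udf`.
# Caveat: the label is a substring match, so any path containing 'ants'
# (e.g. a 'plants' directory) is also labeled 1.0.
import os


def file_name_from_origin(origin):
    """Return the last path component of an image origin path or URI."""
    return os.path.basename(origin)


def label_from_origin(origin):
    """Map an origin path to a float label: contains 'ants' -> 1.0, otherwise 2.0."""
    return 1.0 if 'ants' in origin else 2.0


# Hypothetical wiring back into Spark:
# getFileName = udf(lambda row: file_name_from_origin(row[0]), StringType())
# getLabel = udf(lambda row: label_from_origin(row[0]), DoubleType())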