@mrafayaleem
Created March 7, 2019 21:48
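# plt and Image are used by showImage; NNImageReader, sc, test_path, getFileName,
# getLabel and antbeeModel are assumed to be defined elsewhere.
import matplotlib.pyplot as plt
from PIL import Image

# Test dataframe
testDF = NNImageReader.readImages(test_path, sc, resizeH=300, resizeW=300, image_codec=1)
testDF = testDF.withColumn('filename', getFileName('image')).withColumn('label', getLabel('image'))
testPredDF = antbeeModel.transform(testDF).cache()
row = testPredDF.first().asDict()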
# showImage function
def showImage(row):
    # Open file
    plt.imshow(Image.open(row['image'][0][5:]))
    # Map prediction to class
    title = 'ants' if row['prediction'] == 1.0 else 'bees'
    plt.title(title)

showImage(row)
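
The `[5:]` slice in `showImage` appears to strip a leading `file:` from the image origin before opening it. A minimal sketch of a less position-dependent alternative follows, assuming `row['image'][0]` is a `file:` URI or a plain local path. The helper name `origin_to_path` is hypothetical and not part of the original gist or of any library.

```python
from urllib.parse import urlparse, unquote


def origin_to_path(origin):
    """Convert a file: URI (or a plain path) to a local filesystem path.

    Hypothetical helper: assumes the origin is either a 'file:' URI such as
    'file:/data/ants/1.jpg' or already a bare POSIX path.
    """
    parsed = urlparse(origin)
    # Anything other than a local file (for example hdfs://) cannot be
    # opened directly by PIL, so fail loudly instead of slicing blindly.
    if parsed.scheme not in ("", "file"):
        raise ValueError(f"unsupported URI scheme: {parsed.scheme!r}")
    # Decode percent-escapes such as %20 back into literal characters.
    return unquote(parsed.path)
```

Under that assumption, `Image.open(origin_to_path(row['image'][0]))` could stand in for the slice inside `showImage`.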