
@bgweber
Created May 15, 2019 03:26
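import pandas as pd

# Assumes a Databricks notebook where `spark`, `display`, `spark_df` and a fitted
# classifier `model` exposing predict_proba() (e.g. scikit-learn) already exist.

# pull all data to the driver node (only works if it fits in driver memory)
sample_df = spark_df.toPandas()
# create a prediction for each user
ids = sample_df['user_id']
# keep only the feature columns: drop the label and bookkeeping columns
x_train = sample_df.drop(['label', 'user_id', 'partition_id'], axis=1)
# column 1 = probability of the second class in model.classes_ (positive class for 0/1 labels)
pred = model.predict_proba(x_train)
result_df = pd.DataFrame({'user_id': ids, 'prediction': pred[:, 1]})
# display the results (display() is a Databricks notebook helper)
display(spark.createDataFrame(result_df))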
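To try the driver-side scoring step without a Spark cluster, the sketch below runs the same logic on a small local pandas DataFrame. It is illustrative only: `LogisticSumModel`, `score_users`, and the feature columns `f1`/`f2` are hypothetical stand-ins, not part of the original gist. The toy model simply mimics the two-column `predict_proba` layout that scikit-learn classifiers return.

```python
import numpy as np
import pandas as pd


class LogisticSumModel:
    """Hypothetical stand-in for a fitted classifier with predict_proba().

    Applies a logistic function to the sum of each row's features and returns
    an (n_rows, 2) array: column 0 = P(class 0), column 1 = P(class 1),
    matching the scikit-learn predict_proba layout.
    """

    def predict_proba(self, X):
        p1 = 1.0 / (1.0 + np.exp(-X.to_numpy(dtype=float).sum(axis=1)))
        return np.column_stack([1.0 - p1, p1])


def score_users(sample_df, model):
    """Score every user in a local pandas DataFrame (hypothetical helper)."""
    ids = sample_df['user_id']
    # keep only feature columns; label and bookkeeping columns are excluded
    features = sample_df.drop(['label', 'user_id', 'partition_id'], axis=1)
    pred = model.predict_proba(features)
    # .to_numpy() makes the pairing purely positional, independent of the index
    return pd.DataFrame({'user_id': ids.to_numpy(), 'prediction': pred[:, 1]})


# toy stand-in for spark_df.toPandas(); f1 and f2 are made-up feature columns
sample_df = pd.DataFrame({
    'user_id': [101, 102, 103],
    'label': [0, 1, 0],
    'partition_id': [0, 1, 0],
    'f1': [0.0, 2.0, -1.0],
    'f2': [0.0, 1.0, -1.0],
})
result_df = score_users(sample_df, LogisticSumModel())
print(result_df)
```

Because everything runs in one pandas process, this pattern only scales as far as a single machine's memory allows, which is the same limit as calling `toPandas()` on the driver.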