Skip to content

Instantly share code, notes, and snippets.

@afranzi
Last active October 25, 2018 21:10
Show Gist options
  • Save afranzi/8cf86671470ee176e6b0b30929c11d42 to your computer and use it in GitHub Desktop.
Save afranzi/8cf86671470ee176e6b0b30929c11d42 to your computer and use it in GitHub Desktop.
MLflow UDFs from Scala Spark
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"FirstAtRe = ^_\n",
"AliasRe = [\\s_.:@]+\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"getFieldAlias: (field_name: String)String\n",
"selectFieldsNormalized: (columns: List[String])(df: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\n",
"normalizeSchema: (df: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[\\s_.:@]+"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import org.apache.spark.sql.functions.col\n",
"import org.apache.spark.sql.types.StructType\n",
"import org.apache.spark.sql.{Column, DataFrame}\n",
"import scala.util.matching.Regex\n",
"\n",
"// Matches a single underscore at the very start of a name.\n",
"val FirstAtRe: Regex = \"^_\".r\n",
"// Matches runs of whitespace, underscores, dots, colons and at-signs.\n",
"val AliasRe: Regex = \"[\\\\s_.:@]+\".r\n",
"\n",
"/** Builds a normalized alias for a column name.\n",
"  *\n",
"  * Every run of characters matched by AliasRe is collapsed into one underscore,\n",
"  * then the leading underscore matched by FirstAtRe (if any) is removed.\n",
"  *\n",
"  * @param field_name original column name\n",
"  * @return the alias to use for that column\n",
"  */\n",
"def getFieldAlias(field_name: String): String = {\n",
"  val collapsed = AliasRe.replaceAllIn(field_name, \"_\")\n",
"  FirstAtRe.replaceAllIn(collapsed, \"\")\n",
"}\n",
"\n",
"/** Selects the given columns from df, renaming each one to its normalized alias.\n",
"  *\n",
"  * The second parameter list allows the partially applied function to be handed\n",
"  * to DataFrame.transform.\n",
"  *\n",
"  * @param columns names of the columns to select\n",
"  * @param df      DataFrame to select from\n",
"  * @return a DataFrame holding only the requested columns under their aliases\n",
"  */\n",
"def selectFieldsNormalized(columns: List[String])(df: DataFrame): DataFrame = {\n",
"  val aliased: List[Column] =\n",
"    for (name <- columns) yield col(name).as(getFieldAlias(name))\n",
"  df.select(aliased: _*)\n",
"}\n",
"\n",
"/** Re-selects every column of df under its normalized alias.\n",
"  *\n",
"  * @param df DataFrame whose column names should be normalized\n",
"  * @return df with each of its columns aliased through getFieldAlias\n",
"  */\n",
"def normalizeSchema(df: DataFrame): DataFrame = {\n",
"  selectFieldsNormalized(df.columns.toList)(df)\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"winePath = ~/Research/mlflow-workshop/examples/wine_quality/data/winequality-red.csv\n",
"modelPath = /tmp/mlflow/artifactStore/0/96cba14c6e4b452e937eb5072467bf79/artifacts/model\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"/tmp/mlflow/artifactStore/0/96cba14c6e4b452e937eb5072467bf79/artifacts/model"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"val winePath = \"~/Research/mlflow-workshop/examples/wine_quality/data/winequality-red.csv\"\n",
"val modelPath = \"/tmp/mlflow/artifactStore/0/96cba14c6e4b452e937eb5072467bf79/artifacts/model\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"df = [fixed_acidity: string, volatile_acidity: string ... 10 more fields]\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[fixed_acidity: string, volatile_acidity: string ... 10 more fields]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"// Read the semicolon-delimited CSV at winePath (first row is the header),\n",
"// then normalize the column names.\n",
"val df = normalizeSchema(\n",
"  spark.read\n",
"    .format(\"csv\")\n",
"    .options(Map(\"header\" -> \"true\", \"delimiter\" -> \";\"))\n",
"    .load(winePath)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<function spark_udf.<locals>.predict at 0x1116a98c8>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%PySpark\n",
"import mlflow\n",
"from mlflow import pyfunc\n",
"\n",
"model_path = \"/tmp/mlflow/artifactStore/0/96cba14c6e4b452e937eb5072467bf79/artifacts/model\"\n",
"wine_quality_udf = mlflow.pyfunc.spark_udf(spark, model_path)\n",
"\n",
"spark.udf.register(\"wineQuality\", wine_quality_udf)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"df.createOrReplaceTempView(\"wines\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"+-------+------------------+\n",
"|quality| prediction|\n",
"+-------+------------------+\n",
"| 5| 5.576883967129615|\n",
"| 5| 5.50664776916154|\n",
"| 5| 5.525504822954496|\n",
"| 6| 5.504311247097457|\n",
"| 5| 5.576883967129615|\n",
"| 5|5.5556903912725755|\n",
"| 5| 5.467882654744997|\n",
"| 7| 5.710602976324739|\n",
"| 7| 5.657319539336507|\n",
"| 5| 5.345098606538708|\n",
"+-------+------------------+\n",
"\n"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%SQL\n",
"SELECT \n",
" quality,\n",
" wineQuality(\n",
" fixed_acidity,\n",
" volatile_acidity,\n",
" citric_acid,\n",
" residual_sugar,\n",
" chlorides,\n",
" free_sulfur_dioxide,\n",
" total_sulfur_dioxide,\n",
" density,\n",
" pH,\n",
" sulphates,\n",
" alcohol\n",
" ) AS prediction\n",
"FROM wines\n",
"LIMIT 10"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----------+--------+-----------+---------+-----------+\n",
"|name |database|description|className|isTemporary|\n",
"+-----------+--------+-----------+---------+-----------+\n",
"|wineQuality|null |null |null |true |\n",
"+-----------+--------+-----------+---------+-----------+\n",
"\n"
]
}
],
"source": [
"spark.catalog.listFunctions.filter('name like \"%wineQuality%\").show(20, false)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Apache Toree - Scala",
"language": "scala",
"name": "apache_toree_scala"
},
"language_info": {
"codemirror_mode": "text/x-scala",
"file_extension": ".scala",
"mimetype": "text/x-scala",
"name": "scala",
"pygments_lexer": "scala",
"version": "2.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment