Last active
September 5, 2020 14:16
-
-
Save satylogin/0011d47528cfc85f2ac4309b23ef7b89 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.nio.charset.StandardCharsets

import com.amazonaws.services.sagemaker.sparksdk.transformation.RequestRowSerializer
import com.amazonaws.services.sagemaker.sparksdk.transformation.ResponseRowDeserializer
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

import scala.util.parsing.json._
/**
 * Serializes the leading `featureLength` columns of a Spark [[Row]] into a single
 * CSV line (`text/csv`) for an inference request.
 *
 * @param featureLength number of leading row columns to send as features; must be
 *                      non-negative
 */
class TFRequestRowSerializer (val featureLength: Int) extends RequestRowSerializer {
  require(featureLength >= 0, s"featureLength must be non-negative, got $featureLength")

  override val contentType: String = "text/csv"

  /**
   * Converts the first `featureLength` cells of `row` to doubles and joins them with
   * commas, terminated by a newline.
   *
   * Each cell is converted through its string form, so any cell whose `toString`
   * parses as a double is accepted. Rows with fewer than `featureLength` columns are
   * padded with `0.0` so the emitted line always has exactly `featureLength` values.
   *
   * @param row the input row; only its first `featureLength` columns are read
   * @return the UTF-8 encoded CSV line, including the trailing newline
   * @throws NumberFormatException if a cell's string form is not a valid double
   * @throws NullPointerException  if one of the read cells is null
   */
  override def serializeRow(row: Row): Array[Byte] = {
    val features: Seq[Double] = row.toSeq
      .take(featureLength)
      .map(_.toString.toDouble)
      .padTo(featureLength, 0.0)

    // Encode with an explicit charset so the payload does not depend on the
    // JVM's platform-default encoding.
    (features.mkString(",") + "\n").getBytes(StandardCharsets.UTF_8)
  }
}
/**
 * Deserializes a JSON inference response into rows holding one double-valued
 * prediction each.
 *
 * The response body must be a JSON object with a `"predictions"` array. Each entry
 * is either a list whose first element is the prediction (e.g. `[[0.1], [0.9]]`) or
 * a bare number (e.g. `[0.1, 0.9]`). This looks like the TensorFlow Serving REST
 * predict response format, judging by the class name -- confirm against the endpoint.
 *
 * Note that the `"predictions"` JSON key is fixed; `predictionColumnName` only
 * controls the name of the output column.
 *
 * @param predictionColumnName name of the output column in [[schema]]
 */
class TFResponseRowDeserializer(
  predictionColumnName: String = "predictions"
) extends ResponseRowDeserializer {
  override val accepts: String = "application/json"

  override val schema: StructType = StructType(
    Array(StructField(predictionColumnName, DoubleType, nullable = false))
  )

  /**
   * Parses `responseData` and yields one single-column [[Row]] per prediction entry.
   *
   * The whole response is validated before the iterator is returned, so a malformed
   * body fails here rather than while the rows are being consumed.
   *
   * @param responseData raw response body, decoded as UTF-8 JSON
   * @return an iterator over the prediction rows, in response order
   * @throws IllegalArgumentException if the body is not a JSON object, lacks a
   *                                  `"predictions"` array, or contains an entry that
   *                                  is neither a number nor a list starting with one
   */
  override def deserializeResponse(responseData: Array[Byte]): Iterator[Row] = {
    // Decode with an explicit charset instead of the platform default.
    val body = new String(responseData, StandardCharsets.UTF_8)

    val predictions: List[Any] = JSON.parseFull(body) match {
      // JSON object keys are always strings, so the erased key type is safe here.
      case Some(fields: Map[String, Any] @unchecked) =>
        fields.get("predictions") match {
          case Some(values: List[_]) => values
          case _ =>
            throw new IllegalArgumentException(
              "Response JSON has no \"predictions\" array: " + body.take(200)
            )
        }
      case _ =>
        throw new IllegalArgumentException(
          "Response is not a JSON object: " + body.take(200)
        )
    }

    // Map eagerly so that a bad entry is reported at deserialization time.
    val rows: List[Row] = predictions.map {
      case (first: Double) :: _ => Row(first)
      case score: Double => Row(score)
      case other =>
        throw new IllegalArgumentException(
          s"Prediction entry is neither a number nor a list starting with a number: $other"
        )
    }
    rows.iterator
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment