import com.amazonaws.services.sagemaker.sparksdk.transformation.RequestRowSerializer
import com.amazonaws.services.sagemaker.sparksdk.transformation.ResponseRowDeserializer
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import scala.util.parsing.json._
/** Serializes the first `featureLength` columns of each Row as one CSV line,
  * the text/csv request body expected by the TensorFlow Serving endpoint. */
class TFRequestRowSerializer(val featureLength: Int) extends RequestRowSerializer {

  override val contentType: String = "text/csv"

  override def serializeRow(row: Row): Array[Byte] = {
    val featureVector: Array[Double] = new Array[Double](featureLength)
    // Copy the leading featureLength columns, coercing each value to Double;
    // any trailing columns (e.g. labels) are ignored.
    for ((x, i) <- row.toSeq.slice(0, featureLength).zipWithIndex) {
      featureVector(i) = x.toString.toDouble
    }
    val request = featureVector.mkString(",") + "\n"
    request.getBytes
  }
}
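
// --- Usage sketch (not part of the original gist): a quick local check of the
// serializer. The Row values and feature length are made-up examples.
object TFRequestRowSerializerDemo {
  def main(args: Array[String]): Unit = {
    val serializer = new TFRequestRowSerializer(featureLength = 2)
    // Columns beyond featureLength (here the "label" string) are dropped.
    val bytes = serializer.serializeRow(Row(0.5, 1.5, "label"))
    println(new String(bytes)) // prints "0.5,1.5"
  }
}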
/** Parses a TensorFlow Serving JSON response of the form
  * {"predictions": [[0.1], [0.9], ...]} into single-column Rows of Double. */
class TFResponseRowDeserializer(
  predictionColumnName: String = "predictions"
) extends ResponseRowDeserializer {

  override val accepts: String = "application/json"

  override val schema: StructType = StructType(
    Array(StructField(predictionColumnName, DoubleType, nullable = false))
  )

  override def deserializeResponse(responseData: Array[Byte]): Iterator[Row] = {
    val response = JSON.parseFull(new String(responseData))
      .get.asInstanceOf[Map[String, List[List[Double]]]]
    // TensorFlow Serving always keys the result list as "predictions",
    // regardless of the output column name chosen for the schema.
    val predictions = response("predictions")
    // Each prediction is a one-element list; take its first value per row.
    predictions.map(r => Row(r(0))).toIterator
  }
}
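
// --- Usage sketch (assumption, not part of the original gist): wiring the pair
// into the SageMaker Spark SDK's SageMakerModel.fromEndpoint factory to score a
// DataFrame against an already-deployed TensorFlow Serving endpoint. The
// endpoint name, feature length, and input DataFrame are placeholders.
import com.amazonaws.services.sagemaker.sparksdk.SageMakerModel
import org.apache.spark.sql.DataFrame

object TFEndpointScoring {
  def score(features: DataFrame): DataFrame = {
    val model = SageMakerModel.fromEndpoint(
      endpointName = "my-tf-endpoint", // placeholder: name of a live endpoint
      requestRowSerializer = new TFRequestRowSerializer(featureLength = 2),
      responseRowDeserializer = new TFResponseRowDeserializer()
    )
    // transform invokes the endpoint and appends the "predictions" column
    // defined by the deserializer's schema.
    model.transform(features)
  }
}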