Created
September 28, 2020 21:17
-
-
Save zeryx/4a14d4deb4d8dc2e7f8d1ee01c74bfd0 to your computer and use it in GitHub Desktop.
MLeap runtime project, for running a Spark model on Algorithmia
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.algorithmia | |
import com.algorithmia.handler.AbstractAlgorithm | |
import ml.combust.bundle.BundleFile | |
import ml.combust.bundle.dsl.Bundle | |
import ml.combust.mleap.core.types._ | |
import ml.combust.mleap.runtime.MleapSupport._ | |
import ml.combust.mleap.runtime.frame.{DefaultLeapFrame, Row, Transformer} | |
import scala.collection.mutable | |
import scala.util.{Failure, Success, Try} | |
/**
 * Algorithmia algorithm that runs an MLeap-serialized Spark pipeline over the rows of an
 * [[InputExample]].
 *
 * The bundle referenced by `model_uri` is fetched through the Algorithmia client, loaded into the
 * MLeap runtime by `load()`, and cached in `loaded_state` under the key "model". `apply` reuses
 * the cached bundle and only calls `load()` itself when nothing has been cached yet.
 */
class Algorithm extends AbstractAlgorithm[InputExample, String] {
  // Mutable holder for the loaded MLeap bundle; `load()` stores it under the key "model".
  // Left as a `var` so the class's existing public surface is unchanged.
  var loaded_state = new mutable.HashMap[String, Bundle[Transformer]]()
  val model_uri = "data://algorithmiahq/mleap/simple-spark-pipeline.zip"
  // If you want to run this algorithm locally, pass an Algorithmia API key to Algorithmia.client(...) below.
  val client: AlgorithmiaClient = Algorithmia.client()

  /**
   * Fetches the model archive and loads it into the MLeap runtime.
   *
   * @return `Success(())` once the bundle is stored in `loaded_state("model")`; a `Failure`
   *         carrying the underlying exception if fetching the file, opening the archive or
   *         deserializing the bundle fails.
   */
  override def load(): Try[Unit] = {
    val loaded: Try[Bundle[Transformer]] = for {
      // Fetching can throw; wrap it so the error is reported through the Try return value.
      localPath <- Try(this.client.file(this.model_uri).getFile.getPath)
      bundle <- readBundle(s"jar:file:$localPath")
    } yield bundle

    loaded.map { bundle =>
      this.loaded_state.put("model", bundle)
      ()
    }
  }

  /**
   * Builds a leap frame from the input rows, runs the loaded pipeline over it and returns the
   * resulting dataset rendered into a greeting string.
   *
   * @param input rows to score; each row contributes its `field_name` and `value`
   * @return `Success("Hello " + dataset)` on success; a `Failure` if the model cannot be loaded,
   *         the schema cannot be built, or the pipeline transform fails.
   */
  override def apply(input: InputExample): Try[String] = {
    val data: Seq[Row] = input.rows.map(row => Row(row.field_name, row.value))
    for {
      bundle <- cachedModel()
      schema <- StructType(StructField("test_string", ScalarType.String),
        StructField("test_double", ScalarType.Double))
      transformed <- bundle.root.transform(DefaultLeapFrame(schema, data))
    } yield {
      val result = transformed.dataset
      Console.println(result)
      "Hello " + result
    }
  }

  // Opens the bundle archive at `uri`, loads the MLeap bundle from it, and always closes the archive.
  private def readBundle(uri: String): Try[Bundle[Transformer]] =
    Try(BundleFile(uri)).flatMap { bundleFile =>
      try bundleFile.loadMleapBundle()
      finally bundleFile.close()
    }

  // Returns the cached bundle, calling load() first only when no bundle has been cached yet.
  private def cachedModel(): Try[Bundle[Transformer]] =
    this.loaded_state.get("model") match {
      case Some(bundle) => Success(bundle)
      case None =>
        this.load().flatMap { _ =>
          this.loaded_state.get("model") match {
            case Some(bundle) => Success(bundle)
            case None => Failure(new IllegalStateException("MLeap model was not available after load()"))
          }
        }
    }
}
/** Companion object holding the shared handler and the JVM entry point. */
object Algorithm {
  // Built once, when this object is initialized, around a fresh Algorithm instance.
  val handler = Algorithmia.handler(new Algorithm)

  /** JVM entry point: invokes `serve()` on the shared handler. `args` is not used. */
  def main(args: Array[String]): Unit = handler.serve()
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.algorithmia | |
import play.api.libs.json._ | |
// For JSON serialization/deserialization to work, each input and output type needs a case class
// plus a companion object providing an implicit `Reads` (for types we ingest) or `Writes` (for types we return).
/** One input row: a string field paired with a double value. */
case class InputRow(field_name: String, value: Double)

object InputRow {
  // JSON reader for a single row, derived from the case class fields.
  implicit val reads: Reads[InputRow] = Json.reads[InputRow]
}

/** Top-level request payload: the list of rows to process. */
case class InputExample(rows: List[InputRow])

object InputExample {
  // JSON reader for the payload, derived from the case class fields.
  implicit val reads: Reads[InputExample] = Json.reads[InputExample]
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment