Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save sdabbour-stratio/de2957aa9b3e86a19fe721be1b784071 to your computer and use it in GitHub Desktop.
Save sdabbour-stratio/de2957aa9b3e86a19fe721be1b784071 to your computer and use it in GitHub Desktop.
package com.stratio.governance.unstructured.rcn.inference
import com.johnsnowlabs.nlp.DocumentAssembler
import com.johnsnowlabs.nlp.annotator._
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.SparkSession
// Needs one argument, the path to the NER model
/**
 * Command-line job that wraps a saved NerDL model into a complete Spark ML
 * pipeline (document assembler -> sentence detector -> tokenizer -> BERT
 * embeddings -> NER model -> NER converter), fits it on a one-row sample and
 * saves the fitted pipeline next to the model as `<model-path>_pipeline`.
 *
 * Expects exactly one meaningful argument: the path to the saved NER model.
 */
object NerRcnModelSavePipeline {

  private val Usage = "Usage: NerRcnModelSavePipeline <path-to-ner-model>"

  // One-row sample used only to fit the pipeline before it is persisted.
  private val SampleText =
    "Mi nombre es Felipe Alvarez Angulo y mi dirección es Avd Cerro Milano 143"

  /**
   * Builds, fits and saves the inference pipeline.
   *
   * @param args command-line arguments; `args(0)` is the path to the NER model
   * @throws IllegalArgumentException if the model path is missing or empty
   */
  def main(args: Array[String]): Unit = {
    // Validate before starting Spark so a bad invocation fails fast with a usage hint.
    val modelPath = args.headOption.filter(_.nonEmpty).getOrElse {
      throw new IllegalArgumentException(Usage)
    }
    val pipelinePath = pipelinePathFor(modelPath)

    val spark: SparkSession = SparkSession
      .builder()
      .appName("create-pipeline")
      .master("local[4]")
      .config("spark.driver.memory", "6G")
      .config("spark.kryoserializer.buffer.max", "200M")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .getOrCreate()

    // Stop the session even when loading, fitting or saving fails.
    try {
      val documentAssembler = new DocumentAssembler()
        .setInputCol("text")
        .setOutputCol("document")

      val sentenceDetector = new SentenceDetector()
        .setInputCols("document")
        .setOutputCol("sentence")

      val tokenizer = new Tokenizer()
        .setInputCols("sentence")
        .setOutputCol("token")

      val bertEmbeddings = BertEmbeddings
        .pretrained("bert_multi_cased", lang = "xx")
        .setInputCols("document", "token")
        .setOutputCol("embeddings")

      val nerModel = NerDLModel
        .load(modelPath)
        .setInputCols("sentence", "token", "embeddings")
        .setOutputCol("ner")
        .setIncludeConfidence(true)

      val nerConverter = new NerConverter()
        .setInputCols("document", "token", "ner")
        .setOutputCol("ner_span")

      val pipeline = new Pipeline().setStages(
        Array(
          documentAssembler,
          sentenceDetector,
          tokenizer,
          bertEmbeddings,
          nerModel,
          nerConverter
        )
      )

      val sampleData = spark
        .createDataFrame(Seq((0, SampleText)))
        .toDF("id", "text")

      // fit() turns the Pipeline into a PipelineModel, which is what gets persisted.
      val fittedPipeline = pipeline.fit(sampleData)
      fittedPipeline.write.overwrite().save(pipelinePath)
    } finally {
      spark.stop()
    }
  }

  /**
   * Derives the output location of the fitted pipeline from the model path:
   * trailing path separators (`/` or `\`) are dropped and `_pipeline` is
   * appended, so the pipeline is written beside the model rather than inside it.
   *
   * @param modelPath path to the saved NER model
   * @return `modelPath` without trailing separators, suffixed with `_pipeline`
   */
  private[inference] def pipelinePathFor(modelPath: String): String = {
    val withoutTrailingSeparators =
      modelPath.reverse.dropWhile(c => c == '/' || c == '\\').reverse
    s"${withoutTrailingSeparators}_pipeline"
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment