package com.stratio.governance.unstructured.rcn.inference

import com.johnsnowlabs.nlp.annotators.ner.dl.NerDLApproach
import com.johnsnowlabs.nlp.embeddings.BertEmbeddings
import com.johnsnowlabs.nlp.training.CoNLL
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.LocalFileSystem
import org.apache.hadoop.hdfs.DistributedFileSystem
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.SparkSession
object Trainer {

  def main(args: Array[String]): Unit = {
    if (args.length < 5) {
      println("SYNTAX: java -cp <jar_name>.jar NerRcnSparkTrain <training_filename> <graphFolder> <model_filename> <driverMemory> <numEpochs> ['CREATE_PIPELINE']")
      // System.exit(1) intentionally disabled: the values below are hardcoded for this test run.
    }

    // Hardcoded test values; in a real run these would come from args(0)..args(4).
    val training_filename = "hdfs://hdfs:9000/tmp/test_100k_1.train.txt" // args(0)
    val graphFolder = "/tmp/" // args(1)
    val model_filename = "model_test" // args(2)
    val driverMemory = "16g" // args(3)
    val executorMemory = "16g" // not among the documented arguments; hardcoded
    val numEpochs = 3 // Integer.parseInt(args(4))
    // Create the Spark session.
    val spark: SparkSession = SparkSession
      .builder()
      .appName("test_train")
      .master("spark://spark-master:7077")
      .config("spark.driver.maxResultSize", "4g")
      .config("spark.driver.memory", driverMemory)
      .config("spark.memory.fraction", "0.9")
      .config("spark.executor.memory", executorMemory)
      .config("spark.kryoserializer.buffer.max", "1000M")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .getOrCreate()

    // Make sure hdfs:// and file:// URIs resolve to the intended FileSystem implementations.
    val hadoopConfig: Configuration = spark.sparkContext.hadoopConfiguration
    hadoopConfig.set("fs.hdfs.impl", classOf[DistributedFileSystem].getName)
    hadoopConfig.set("fs.file.impl", classOf[LocalFileSystem].getName)
    println("Reading CoNLL training dataset: " + training_filename)
    val t1 = System.currentTimeMillis()
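    // The reader expects CoNLL-2003 formatted text: one "token POS chunk label" row per
    // token, with blank lines between sentences. Illustrative sample (not from the gist):
    //   -DOCSTART- -X- -X- O
    //
    //   John NNP B-NP B-PER
    //   Smith NNP I-NP I-PER
    //   works VBZ B-VP O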
    val training_data = CoNLL().readDataset(spark, training_filename).repartition(1000)

    // Pretrained multilingual cased BERT embeddings (768-dimensional vectors).
    val bertEmbeddings = BertEmbeddings.pretrained(name = "bert_multi_cased", lang = "xx")
      .setInputCols("document", "token")
      .setOutputCol("embeddings")
    val nerTagger = new NerDLApproach()
      .setInputCols("sentence", "token", "embeddings")
      .setLabelColumn("label")
      .setOutputCol("ner")
      .setMaxEpochs(numEpochs)
      .setLr(0.001f)   // learning rate
      .setPo(0.005f)   // learning rate decay coefficient
      .setBatchSize(10)
      .setRandomSeed(0)
      .setVerbose(1)
      .setValidationSplit(0.2f)
      .setEvaluationLogExtended(true)
      .setEnableOutputLogs(true)
      .setIncludeConfidence(true)
      .setGraphFolder(graphFolder)
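    // Assumption worth noting: if no bundled TensorFlow graph matches the training setup,
    // NerDLApproach searches graphFolder for one, typically named
    // blstm_{tags}_{embeddings_dim}_{lstm_size}_{chars}.pb, so the graph there must be
    // generated for 768-dim embeddings (bert_multi_cased) and this dataset's tag set.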
    val ner_pipeline = new Pipeline().setStages(
      Array(
        bertEmbeddings,
        nerTagger
      )
    )

    val ner_model = ner_pipeline.fit(training_data)
    ner_model.write.overwrite().save(model_filename)
    val t2 = System.currentTimeMillis()
    val duration = t2 - t1
    println(s"## Time: $duration")
    val minutes = (duration / 1000) / 60
    val seconds = ((duration / 1000) % 60).toInt
    println(s"## Execution took: $minutes minutes and $seconds seconds")

    spark.stop()
  }
}
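
For reference, a minimal sketch of how the saved model could be applied to raw text. The saved PipelineModel only contains the embeddings and NER stages, so document, sentence and token columns must be produced first (CoNLL().readDataset did that at training time). The prep stages, sample text and load path below are illustrative assumptions, not part of the original gist.

import com.johnsnowlabs.nlp.DocumentAssembler
import com.johnsnowlabs.nlp.annotator.{SentenceDetector, Tokenizer}
import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}

// Assumes a SparkSession named `spark`, configured as in Trainer.
import spark.implicits._

// Recreate the preprocessing columns the trained stages expect.
val documentAssembler = new DocumentAssembler()
  .setInputCol("text")
  .setOutputCol("document")
val sentenceDetector = new SentenceDetector()
  .setInputCols("document")
  .setOutputCol("sentence")
val tokenizer = new Tokenizer()
  .setInputCols("sentence")
  .setOutputCol("token")

// "model_test" matches model_filename above.
val trainedStages = PipelineModel.load("model_test").stages

val inferencePipeline = new Pipeline().setStages(
  Array[PipelineStage](documentAssembler, sentenceDetector, tokenizer) ++ trainedStages
)

val testDF = Seq("John Smith works at Stratio in Madrid.").toDF("text")
val predictions = inferencePipeline.fit(testDF).transform(testDF)
predictions.select("token.result", "ner.result").show(truncate = false)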