package com.stratio.governance.unstructured.rcn.inference

import java.io.File
import com.johnsnowlabs.nlp.annotator.WordEmbeddingsModel
import com.johnsnowlabs.nlp.annotators.ner.dl.NerDLApproach
import com.johnsnowlabs.nlp.embeddings.BertEmbeddings
import com.johnsnowlabs.nlp.training.CoNLL
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.LocalFileSystem
import org.apache.hadoop.hdfs.DistributedFileSystem
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.SparkSession
object Trainer {

  def main(args: Array[String]): Unit = {

    if (args.length < 5) {
      System.out.println("SYNTAX: java -cp <jar_name>.jar NerRcnSparkTrain <training_filename> <graphFolder> <model_filename> <driverMemory> <numEpochs> ['CREATE_PIPELINE']")
      //System.exit(0)
    }

    // Paths and training parameters. They are hard-coded for this test run; the
    // commented args(...) indices show how they would be read from the command line.
    val training_filename = "hdfs://hdfs:9000/tmp/test_100k_1.train.txt" //args(0)
    val graphFolder = "/tmp/" //args(1)
    val model_filename = "model_test" //args(2)
    val driverMemory = "16g" //args(3)
    val executorMemory = "16g"
    val numEpochs = 3 //Integer.parseInt(args(4))
    // Create the Spark session
    val spark: SparkSession =
      SparkSession
        .builder().appName("test_train")
        .master("spark://spark-master:7077")
        .config("spark.driver.maxResultSize", "4g")
        .config("spark.driver.memory", driverMemory)
        .config("spark.memory.fraction", "0.9")
        .config("spark.executor.memory", executorMemory)
        .config("spark.kryoserializer.buffer.max", "1000M")
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        .getOrCreate()

    val hadoopConfig: Configuration = spark.sparkContext.hadoopConfiguration
    hadoopConfig.set("fs.hdfs.impl", classOf[org.apache.hadoop.hdfs.DistributedFileSystem].getName)
    hadoopConfig.set("fs.file.impl", classOf[org.apache.hadoop.fs.LocalFileSystem].getName)
    import spark.implicits._

    System.out.println("Reading CoNLL training dataset: " + training_filename)
    val t1 = System.currentTimeMillis()
    val training_data = CoNLL().readDataset(spark, training_filename).repartition(1000)
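    // CoNLL().readDataset already yields the annotation columns the later stages
    // consume (text, document, sentence, token, pos, label), which is why this
    // training pipeline needs no DocumentAssembler/SentenceDetector/Tokenizer stages.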
    // Multilingual BERT embeddings used as input features for the NER tagger
    val bert_embeddings = BertEmbeddings.pretrained(name = "bert_multi_cased", lang = "xx")
      .setInputCols("document", "token").setOutputCol("embeddings")
    val nerTagger = new NerDLApproach()
      .setInputCols("sentence", "token", "embeddings")
      .setLabelColumn("label")
      .setOutputCol("ner")
      .setMaxEpochs(numEpochs)
      .setLr(0.001f)
      .setPo(0.005f)
      .setBatchSize(10)
      .setRandomSeed(0)
      .setVerbose(1)
      .setValidationSplit(0.2f)
      .setEvaluationLogExtended(true)
      .setEnableOutputLogs(true)
      .setIncludeConfidence(true)
      .setGraphFolder(graphFolder)
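    // setGraphFolder points NerDLApproach at a directory of TensorFlow graph files;
    // a graph compatible with this dataset (embedding dimension, number of tags and
    // characters) must be available there or among the graphs bundled with Spark NLP,
    // otherwise training fails when the approach tries to select one.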
    val ner_pipeline = new Pipeline().setStages(
      Array(
        bert_embeddings,
        nerTagger
      )
    )

    // Train the pipeline and persist the resulting PipelineModel
    val ner_model = ner_pipeline.fit(training_data)
    ner_model.write.overwrite().save(model_filename)
    val t2 = System.currentTimeMillis()
    val duration = t2 - t1
    System.out.println(s"## Time: $duration ms")
    val minutes = (duration / 1000) / 60
    val seconds = ((duration / 1000) % 60).toInt
    System.out.println(s"## Execution took: $minutes minutes and $seconds seconds")

    spark.stop()
  }
}
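For completeness, a minimal inference sketch (not part of the original gist) showing one way the saved PipelineModel could be applied to raw text. Because the training pipeline starts from the document/sentence/token columns that CoNLL().readDataset provides, raw text has to be preprocessed first; the object name InferenceExample, the local master URL and the sample sentence are illustrative assumptions.

package com.stratio.governance.unstructured.rcn.inference

import com.johnsnowlabs.nlp.DocumentAssembler
import com.johnsnowlabs.nlp.annotator.{SentenceDetector, Tokenizer}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.SparkSession

// Hypothetical companion to Trainer: loads the pipeline saved as "model_test"
// and tags free text with it.
object InferenceExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("test_infer").master("local[*]").getOrCreate()
    import spark.implicits._

    val df = Seq("John Smith works at Stratio in Madrid.").toDF("text")

    // Rebuild the preprocessing that CoNLL().readDataset performed during training.
    val documentAssembler = new DocumentAssembler().setInputCol("text").setOutputCol("document")
    val sentenceDetector = new SentenceDetector().setInputCols("document").setOutputCol("sentence")
    val tokenizer = new Tokenizer().setInputCols("sentence").setOutputCol("token")

    val preprocessed = new Pipeline()
      .setStages(Array(documentAssembler, sentenceDetector, tokenizer))
      .fit(df)
      .transform(df)

    // Apply the trained BERT + NerDL stages saved by Trainer.
    val tagged = PipelineModel.load("model_test").transform(preprocessed)
    tagged.select("token.result", "ner.result").show(false)

    spark.stop()
  }
}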