Last active
June 9, 2020 11:04
-
-
Save weinino/2b8796acc271dab8df4f3d5e6e807191 to your computer and use it in GitHub Desktop.
Fine-tuning BERT model in DL4J
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
* Model imported: C:/workspaces/bert/data/multi_cased_L-12_H-768_A-12/frozen/bert_frozen_mb4_len512.pb | |
*/ | |
// Constants | |
public final static String PREPENDED_TOKEN = "[CLS]"; // token inserted at the beginning of each sentence | |
public final static int BATCH_SIZE = 1; | |
public final static int MAX_SENTENCE_LENGTH = 128; | |
public final static List<String> LABELS = Arrays.asList("A", "B"); | |
public void train(IResettableIterator<InputOutputTuple<MlText, MlDecimalVector>> trainDataIterator, IResettableIterator<InputOutputTuple<MlText, MlDecimalVector>> testDataIterator) throws IOException { | |
SameDiff bertModel = TFGraphMapper.importGraph(BertSentenceClassification.class.getResourceAsStream("bert_frozen_mb4_len512.pb")); | |
bertModel.getVariable("Placeholder").setShape(BATCH_SIZE, MAX_SENTENCE_LENGTH); | |
bertModel.getVariable("Placeholder_1").setShape(BATCH_SIZE, MAX_SENTENCE_LENGTH); | |
bertModel.getVariable("Placeholder_2").setShape(BATCH_SIZE, MAX_SENTENCE_LENGTH); | |
bertModel.renameVariable("Placeholder", "tokenIdxs"); | |
bertModel.renameVariable("Placeholder_1", "mask"); | |
bertModel.renameVariable("Placeholder_2", "sentenceIdx"); //only ever 0, but needed by this model... | |
Set<String> floatConstants = new HashSet<>(Arrays.asList( | |
"bert/encoder/ones")); | |
//For training, convert weights and biases from constants to variables: | |
for (SDVariable v : bertModel.variables()) { | |
if (v.isConstant() && v.dataType().isFPType() && !v.getArr().isScalar() && !floatConstants.contains(v.name()) && !v.name().startsWith("bert/embeddings")) { //Skip scalars - trainable params | |
v.convertToVariable(); | |
} | |
} | |
int sizeOfClassificationLayer = 768; | |
int numberOfClasses = LABELS.size(); | |
SDVariable weights = bertModel.var("output_weights", new XavierInitScheme('c', sizeOfClassificationLayer, numberOfClasses), DataType.FLOAT, sizeOfClassificationLayer, numberOfClasses); | |
SDVariable bias = bertModel.var("output_bias", new ZeroInitScheme('c'), DataType.FLOAT, numberOfClasses); | |
SDVariable linear = bertModel.mmul("output_dense", bertModel.getVariable("bert/pooler/dense/Tanh"), weights).add(bias); | |
SDVariable softmax = bertModel.nn.softmax("output", linear); | |
//For training, we'll need to add a label placeholder for one-hot labels: | |
SDVariable label = bertModel.placeHolder("label", DataType.FLOAT, 1, numberOfClasses); | |
bertModel.loss().logLoss("loss",label, softmax); | |
//Set up training configuration... | |
bertModel.setTrainingConfig(TrainingConfig.builder() | |
.updater(new Sgd.Builder().learningRate(0.0001).build()) | |
.l2(1e-5) | |
.dataSetFeatureMapping("tokenIdxs", "sentenceIdx") | |
.dataSetFeatureMaskMapping("mask") | |
.dataSetLabelMapping("label") | |
.build()); | |
MultiDataSetIterator trainIterator = getDataSetIterator(trainDataIterator, true); | |
MultiDataSetIterator testIterator = getDataSetIterator(testDataIterator, false); | |
System.out.println(); | |
System.out.println("==============================================================================================="); | |
System.out.println("=============================== Train (before) ====================================="); | |
System.out.println("==============================================================================================="); | |
Evaluation e = new Evaluation(); | |
bertModel.evaluate(trainIterator, "output", 0, e); | |
System.out.println(e.stats()); | |
System.out.println(); | |
System.out.println("==============================================================================================="); | |
System.out.println("=============================== Test (before) ====================================="); | |
System.out.println("==============================================================================================="); | |
e = new Evaluation(); | |
bertModel.evaluate(testIterator, "output", 0, e); | |
System.out.println(e.stats()); | |
for (int i = 1; i <= 4; i++) { | |
// train | |
bertModel.fit(trainIterator, 1); | |
System.out.println(); | |
System.out.println("==============================================================================================="); | |
System.out.println(String.format("=============================== Train Epoch %d =====================================", i)); | |
System.out.println("==============================================================================================="); | |
e = new Evaluation(); | |
// System.out.println("Train Prediction" + bertModel.output(trainIterator, "output")); | |
bertModel.evaluate(trainIterator, "output", 0, e); | |
System.out.println(String.format("Train Evaluation, end epoch %d:", i)); | |
System.out.println(e.stats()); | |
System.out.println(); | |
System.out.println("==============================================================================================="); | |
System.out.println(String.format("=============================== Test Epoch %d =====================================", i)); | |
System.out.println("==============================================================================================="); | |
e = new Evaluation(); | |
// System.out.println("Test Prediction" + bertModel.output(testIterator, "output")); | |
bertModel.evaluate(testIterator, "output", 0, e); | |
System.out.println(String.format("Test Evaluation, end epoch %d:", i)); | |
System.out.println(e.stats()); | |
} | |
} | |
protected MultiDataSetIterator getDataSetIterator(IResettableIterator<InputOutputTuple<MlText, MlDecimalVector>> trainDataIterator, boolean isRandomizeData) throws IOException { | |
LabelAwareIterator labelAwareIterator = new LabeledSentenceIterator(trainDataIterator, LABELS, isRandomizeData); | |
LabelAwareConverter sentenceProvider = new LabelAwareConverter(labelAwareIterator, LABELS); | |
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(BertSentenceClassification.class.getResourceAsStream("vocab.txt"), false, false, StandardCharsets.UTF_8); | |
return BertIterator.builder() | |
.tokenizer(t) | |
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, MAX_SENTENCE_LENGTH) | |
.minibatchSize(BATCH_SIZE) | |
.padMinibatches(true) | |
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) | |
.vocabMap(t.getVocab()) | |
.task(BertIterator.Task.SEQ_CLASSIFICATION) | |
.prependToken(PREPENDED_TOKEN) | |
.sentenceProvider(sentenceProvider) | |
.build(); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment