Fine-tuning a BERT model in DL4J
/*
 * Fine-tunes a frozen multilingual BERT model for two-class sentence classification.
 * Model imported: C:/workspaces/bert/data/multi_cased_L-12_H-768_A-12/frozen/bert_frozen_mb4_len512.pb
 *
 * Imports assume a recent DL4J/ND4J release (1.0.0-beta6 or later). IResettableIterator,
 * InputOutputTuple, MlText, MlDecimalVector and LabeledSentenceIterator are project-specific
 * classes and are not imported here.
 */
import org.deeplearning4j.iterator.BertIterator;
import org.deeplearning4j.iterator.provider.LabelAwareConverter;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.weightinit.impl.XavierInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
// Constants
public static final String PREPENDED_TOKEN = "[CLS]"; // classification token inserted at the beginning of each sentence
public static final int BATCH_SIZE = 1;
public static final int MAX_SENTENCE_LENGTH = 128; // sequences are padded/truncated to this many tokens
public static final List<String> LABELS = Arrays.asList("A", "B");
public void train(IResettableIterator<InputOutputTuple<MlText, MlDecimalVector>> trainDataIterator,
                  IResettableIterator<InputOutputTuple<MlText, MlDecimalVector>> testDataIterator) throws IOException {
    // Import the frozen TensorFlow graph into a SameDiff instance
    SameDiff bertModel = TFGraphMapper.importGraph(
            BertSentenceClassification.class.getResourceAsStream("bert_frozen_mb4_len512.pb"));

    // Fix the input shapes and give the generic TF placeholders meaningful names
    bertModel.getVariable("Placeholder").setShape(BATCH_SIZE, MAX_SENTENCE_LENGTH);
    bertModel.getVariable("Placeholder_1").setShape(BATCH_SIZE, MAX_SENTENCE_LENGTH);
    bertModel.getVariable("Placeholder_2").setShape(BATCH_SIZE, MAX_SENTENCE_LENGTH);
    bertModel.renameVariable("Placeholder", "tokenIdxs");
    bertModel.renameVariable("Placeholder_1", "mask");
    bertModel.renameVariable("Placeholder_2", "sentenceIdx"); // only ever 0, but required by this model

    // Float constants that must remain constant (i.e. must not be made trainable)
    Set<String> floatConstants = new HashSet<>(Arrays.asList("bert/encoder/ones"));

    // For training, convert floating-point weights and biases from constants to variables.
    // Scalars and the embedding tables are skipped and stay frozen.
    for (SDVariable v : bertModel.variables()) {
        if (v.isConstant() && v.dataType().isFPType() && !v.getArr().isScalar()
                && !floatConstants.contains(v.name()) && !v.name().startsWith("bert/embeddings")) {
            v.convertToVariable();
        }
    }
    // Classification head: a dense layer (768 -> numberOfClasses) on top of BERT's
    // pooled [CLS] output ("bert/pooler/dense/Tanh"), followed by softmax
    int sizeOfClassificationLayer = 768;
    int numberOfClasses = LABELS.size();
    SDVariable weights = bertModel.var("output_weights",
            new XavierInitScheme('c', sizeOfClassificationLayer, numberOfClasses),
            DataType.FLOAT, sizeOfClassificationLayer, numberOfClasses);
    SDVariable bias = bertModel.var("output_bias", new ZeroInitScheme('c'), DataType.FLOAT, numberOfClasses);
    SDVariable linear = bertModel.mmul("output_dense", bertModel.getVariable("bert/pooler/dense/Tanh"), weights).add(bias);
    SDVariable softmax = bertModel.nn.softmax("output", linear);

    // For training, add a placeholder for the one-hot labels and a log loss on the softmax output
    SDVariable label = bertModel.placeHolder("label", DataType.FLOAT, BATCH_SIZE, numberOfClasses);
    bertModel.loss().logLoss("loss", label, softmax);
    // Training configuration: plain SGD with a small learning rate plus L2 regularization.
    // The mappings connect the iterator's feature/mask/label arrays to the placeholders above.
    bertModel.setTrainingConfig(TrainingConfig.builder()
            .updater(new Sgd(1e-4))
            .l2(1e-5)
            .dataSetFeatureMapping("tokenIdxs", "sentenceIdx")
            .dataSetFeatureMaskMapping("mask")
            .dataSetLabelMapping("label")
            .build());
    MultiDataSetIterator trainIterator = getDataSetIterator(trainDataIterator, true);
    MultiDataSetIterator testIterator = getDataSetIterator(testDataIterator, false);

    // Baseline evaluation before any fine-tuning
    printBanner("Train (before)");
    Evaluation e = new Evaluation();
    bertModel.evaluate(trainIterator, "output", 0, e);
    System.out.println(e.stats());

    printBanner("Test (before)");
    e = new Evaluation();
    bertModel.evaluate(testIterator, "output", 0, e);
    System.out.println(e.stats());
    // Fine-tune for four epochs, evaluating on train and test data after each epoch
    for (int i = 1; i <= 4; i++) {
        bertModel.fit(trainIterator, 1);

        printBanner(String.format("Train Epoch %d", i));
        e = new Evaluation();
        // System.out.println("Train Prediction" + bertModel.output(trainIterator, "output"));
        bertModel.evaluate(trainIterator, "output", 0, e);
        System.out.println(String.format("Train Evaluation, end epoch %d:", i));
        System.out.println(e.stats());

        printBanner(String.format("Test Epoch %d", i));
        e = new Evaluation();
        // System.out.println("Test Prediction" + bertModel.output(testIterator, "output"));
        bertModel.evaluate(testIterator, "output", 0, e);
        System.out.println(String.format("Test Evaluation, end epoch %d:", i));
        System.out.println(e.stats());
    }
}

private static void printBanner(String title) {
    System.out.println();
    System.out.println("===============================================================================================");
    System.out.println(String.format("=============================== %s =====================================", title));
    System.out.println("===============================================================================================");
}
protected MultiDataSetIterator getDataSetIterator(IResettableIterator<InputOutputTuple<MlText, MlDecimalVector>> dataIterator,
                                                  boolean randomizeData) throws IOException {
    LabelAwareIterator labelAwareIterator = new LabeledSentenceIterator(dataIterator, LABELS, randomizeData);
    LabelAwareConverter sentenceProvider = new LabelAwareConverter(labelAwareIterator, LABELS);
    // WordPiece tokenizer backed by the vocabulary shipped with the pretrained model
    BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(
            BertSentenceClassification.class.getResourceAsStream("vocab.txt"), false, false, StandardCharsets.UTF_8);
    // Produces token indices, a padding mask and segment ids, padded/truncated to MAX_SENTENCE_LENGTH
    return BertIterator.builder()
            .tokenizer(t)
            .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, MAX_SENTENCE_LENGTH)
            .minibatchSize(BATCH_SIZE)
            .padMinibatches(true)
            .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
            .vocabMap(t.getVocab())
            .task(BertIterator.Task.SEQ_CLASSIFICATION)
            .prependToken(PREPENDED_TOKEN)
            .sentenceProvider(sentenceProvider)
            .build();
}
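
For completeness, a minimal sketch of how this class might be driven. IResettableIterator, InputOutputTuple, MlText and MlDecimalVector are project-specific types rather than DL4J classes, so the loadData(...) helper below is purely hypothetical and stands in for whatever data loading the surrounding project provides:

public static void main(String[] args) throws IOException {
    // loadData(...) is a hypothetical, project-specific helper that yields labeled sentences
    IResettableIterator<InputOutputTuple<MlText, MlDecimalVector>> trainData = loadData("train.tsv");
    IResettableIterator<InputOutputTuple<MlText, MlDecimalVector>> testData = loadData("test.tsv");
    new BertSentenceClassification().train(trainData, testData);
}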