This demo example is built with Gradle.
BertTranslator.java and HuggingFaceQaInference.java are stored in the directory src/main/java.
bert-base-cased-vocab.txt and trace_cased_bertqa.pt are stored in the directory src/main/resources.
The two resources are downloaded from:

- trace_cased_bertqa.pt: https://mlrepo.djl.ai/model/nlp/question_answer/ai/djl/pytorch/bertqa/trace_cased_bertqa/0.0.1/trace_cased_bertqa.pt.gz
- bert-base-cased-vocab.txt: https://mlrepo.djl.ai/model/nlp/question_answer/ai/djl/pytorch/bertqa/trace_cased_bertqa/0.0.1/bert-base-cased-vocab.txt.gz
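For reference, wiring these local resources into DJL presumably looks something like the sketch below. This is not the demo's actual code: the Criteria setup follows DJL's published BERT QA tutorial, the model name passed to optModelName is a guess from the file name, and BertTranslator is assumed to implement Translator<QAInput, String> and to read the vocabulary file itself.

```java
import java.nio.file.Paths;

import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;

public class HuggingFaceQaSketch {
    public static void main(String[] args) throws Exception {
        // QAInput bundles the question and the context paragraph.
        QAInput input = new QAInput(
                "When did BBC Japan start broadcasting?",
                "BBC Japan was a general entertainment channel, which operated "
                        + "between December 2004 and April 2006.");

        // Point DJL at the directory containing trace_cased_bertqa.pt and
        // plug in the demo's custom translator (assumed to handle tokenization
        // with bert-base-cased-vocab.txt and pre/post-processing).
        Criteria<QAInput, String> criteria = Criteria.builder()
                .setTypes(QAInput.class, String.class)
                .optModelPath(Paths.get("src/main/resources"))
                .optModelName("trace_cased_bertqa") // guessed from the file name
                .optTranslator(new BertTranslator())
                .optProgress(new ProgressBar())
                .build();

        try (ZooModel<QAInput, String> model = criteria.loadModel();
                Predictor<QAInput, String> predictor = model.newPredictor()) {
            System.out.println(predictor.predict(input));
        }
    }
}
```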
See the main text for more info.
Note that, when running this code, the default engine may also need to be specified with the VM option -Dai.djl.default_engine=PyTorch, so that the engine matches the model and the tokenizer.
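If you would rather not rely on VM options, the same property can be set programmatically, as long as it happens before DJL initializes any engine:

```java
// Equivalent to -Dai.djl.default_engine=PyTorch; must run before the
// first engine/model is loaded, e.g. at the top of main().
System.setProperty("ai.djl.default_engine", "PyTorch");
```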
The code at https://raw.githubusercontent.com/deepjavalibrary/djl/master/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java worked fine for me.
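Condensed, that example asks DJL's model zoo for a BERT question-answering model and lets the framework handle the download; the sketch below uses DJL's public Criteria API, with the question and paragraph taken from the linked example:

```java
import ai.djl.Application;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;

public class BertQaZooSketch {
    public static void main(String[] args) throws Exception {
        QAInput input = new QAInput(
                "When did BBC Japan start broadcasting?",
                "BBC Japan was a general entertainment channel, which operated "
                        + "between December 2004 and April 2006.");

        // Zoo-based lookup: DJL resolves, downloads, and caches the model
        // artifacts itself (in the .djl.ai cache directory mentioned below).
        Criteria<QAInput, String> criteria = Criteria.builder()
                .optApplication(Application.NLP.QUESTION_ANSWER)
                .setTypes(QAInput.class, String.class)
                .optFilter("backbone", "bert")
                .optEngine("PyTorch") // matches -Dai.djl.default_engine=PyTorch
                .optProgress(new ProgressBar())
                .build();

        try (ZooModel<QAInput, String> model = criteria.loadModel();
                Predictor<QAInput, String> predictor = model.newPredictor()) {
            System.out.println(predictor.predict(input));
        }
    }
}
```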
The framework downloads the vocabulary and the PyTorch model automatically and saves them (on Windows) to C:\Users\<username>\.djl.ai\cache\repo\model\nlp\question_answer\ai\djl\pytorch\bertqa\bert\false\SQuAD\0.0.1
Here are the Maven dependencies; I also set the VM argument -Dai.djl.default_engine=PyTorch.
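A minimal sketch of what those dependencies look like, assuming a recent DJL release; the version below is a placeholder, not necessarily the version used here:

```xml
<dependencies>
    <!-- DJL core API -->
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>api</artifactId>
        <version>0.20.0</version> <!-- placeholder version -->
    </dependency>
    <!-- PyTorch engine; recent versions fetch the native libraries at runtime -->
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-engine</artifactId>
        <version>0.20.0</version> <!-- placeholder version -->
        <scope>runtime</scope>
    </dependency>
</dependencies>
```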
Note: I get a DLL error when I invoke the same working code from an application that loaded a TensorFlow model before this one. I'm pretty sure it's because I'm using both DJL and TensorFlow libraries together; I will try to switch to one of them (probably DJL).