@KexinFeng
Last active January 4, 2024 08:41
Deploying HuggingFace QA model in Java

This demo example is built with gradle.

The project structure

BertTranslator.java and HuggingFaceQaInference.java are stored in the directory "src/main/java". bert-base-cased-vocab.txt and trace_cased_bertqa.pt are stored in the directory "src/main/resources".

Resource dependency

The two resources are downloaded from:
  • trace_cased_bertqa.pt: https://mlrepo.djl.ai/model/nlp/question_answer/ai/djl/pytorch/bertqa/trace_cased_bertqa/0.0.1/trace_cased_bertqa.pt.gz
  • bert-base-cased-vocab.txt: https://mlrepo.djl.ai/model/nlp/question_answer/ai/djl/pytorch/bertqa/trace_cased_bertqa/0.0.1/bert-base-cased-vocab.txt.gz

See the main text for more info.

Note that when running this code, the default engine may also need to be specified with the VM option -Dai.djl.default_engine=PyTorch, so that the engine is compatible with the model and the tokenizer.
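If a VM option is inconvenient (e.g. when launching from an IDE), the same property can be set programmatically before any DJL engine class is loaded. This is a minimal sketch; the class name EngineConfig is made up for illustration:

```java
public class EngineConfig {
    public static void main(String[] args) {
        // Equivalent to the VM option -Dai.djl.default_engine=PyTorch.
        // Must run before any DJL Engine class initializes, since the
        // property is read during engine discovery.
        System.setProperty("ai.djl.default_engine", "PyTorch");
        System.out.println(System.getProperty("ai.djl.default_engine"));
    }
}
```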

import ai.djl.modality.nlp.DefaultVocabulary;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.bert.BertToken;
import ai.djl.modality.nlp.bert.BertTokenizer;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
public class BertTranslator implements Translator<QAInput, String> {

    private List<String> tokens;
    private Vocabulary vocabulary;
    private BertTokenizer tokenizer;

    @Override
    public void prepare(TranslatorContext ctx) throws IOException {
        Path path = Paths.get("src/main/resources/bert-base-cased-vocab.txt");
        vocabulary =
                DefaultVocabulary.builder()
                        .optMinFrequency(1)
                        .addFromTextFile(path)
                        .optUnknownToken("[UNK]")
                        .build();
        tokenizer = new BertTokenizer();
    }

    @Override
    public NDList processInput(TranslatorContext ctx, QAInput input) throws IOException {
        BertToken token =
                tokenizer.encode(
                        input.getQuestion().toLowerCase(),
                        input.getParagraph().toLowerCase());
        // Keep the encoded tokens; processOutput uses them to map indices back to text
        tokens = token.getTokens();
        NDManager manager = ctx.getNDManager();
        // Map the tokens (String) to indices (long)
        long[] indices = tokens.stream().mapToLong(vocabulary::getIndex).toArray();
        long[] attentionMask = token.getAttentionMask().stream().mapToLong(i -> i).toArray();
        long[] tokenType = token.getTokenTypes().stream().mapToLong(i -> i).toArray();
        NDArray indicesArray = manager.create(indices);
        NDArray attentionMaskArray = manager.create(attentionMask);
        NDArray tokenTypeArray = manager.create(tokenType);
        // The order matters
        return new NDList(indicesArray, attentionMaskArray, tokenTypeArray);
    }

    @Override
    public String processOutput(TranslatorContext ctx, NDList list) {
        NDArray startLogits = list.get(0);
        NDArray endLogits = list.get(1);
        int startIdx = (int) startLogits.argMax().getLong();
        int endIdx = (int) endLogits.argMax().getLong();
        return tokenizer.tokenToString(tokens.subList(startIdx, endIdx + 1));
    }

    @Override
    public Batchifier getBatchifier() {
        return Batchifier.STACK;
    }
}
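In processOutput, the answer span is selected by taking the argmax of the start and end logits independently and slicing the token list between them. A self-contained sketch of that selection logic in plain Java (no DJL types; the tokens and logit values here are made-up illustrations):

```java
import java.util.Arrays;
import java.util.List;

public class ArgmaxSpanDemo {

    // Index of the largest value, mirroring what NDArray.argMax() returns
    static int argMax(float[] logits) {
        int best = 0;
        for (int i = 1; i < logits.length; i++) {
            if (logits[i] > logits[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        List<String> tokens = Arrays.asList(
                "[CLS]", "when", "did", "it", "start", "[SEP]",
                "december", "2004", "to", "april", "2006", "[SEP]");
        // Hypothetical logits: highest start score at "december", end at "2004"
        float[] startLogits = {0f, 0f, 0f, 0f, 0f, 0f, 5f, 1f, 0f, 0f, 0f, 0f};
        float[] endLogits   = {0f, 0f, 0f, 0f, 0f, 0f, 1f, 5f, 0f, 0f, 0f, 0f};
        int startIdx = argMax(startLogits);
        int endIdx = argMax(endLogits);
        // Slice tokens[startIdx..endIdx] inclusive, as processOutput does
        System.out.println(String.join(" ", tokens.subList(startIdx, endIdx + 1)));
    }
}
```

Note this simple scheme does not guard against endIdx < startIdx; the real BERT QA examples in DJL handle that case by scoring candidate spans.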
plugins {
    id 'java'
}

repositories {
    mavenCentral()
}

dependencies {
    implementation "org.apache.logging.log4j:log4j-slf4j-impl:2.17.1"
    implementation platform("ai.djl:bom:0.16.0")
    implementation "ai.djl:api"
    runtimeOnly "ai.djl.pytorch:pytorch-engine"
    runtimeOnly "ai.djl.pytorch:pytorch-model-zoo"
}

test {
    useJUnitPlatform()
}
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Paths;
public class HuggingFaceQaInference {

    public static void main(String[] args) throws IOException, TranslateException, ModelException {
        String question = "When did BBC Japan start broadcasting?";
        String paragraph =
                "BBC Japan was a general entertainment Channel. "
                        + "Which operated between December 2004 and April 2006. "
                        + "It ceased operations after its Japanese distributor folded.";
        QAInput input = new QAInput(question, paragraph);
        String answer = HuggingFaceQaInference.qa_predict(input);
        System.out.println("The answer is: \n" + answer);
    }

    public static String qa_predict(QAInput input)
            throws IOException, TranslateException, ModelException {
        BertTranslator translator = new BertTranslator();
        Criteria<QAInput, String> criteria =
                Criteria.builder()
                        .setTypes(QAInput.class, String.class)
                        .optModelPath(Paths.get("src/main/resources/trace_cased_bertqa.pt"))
                        .optTranslator(translator)
                        .optProgress(new ProgressBar())
                        .build();
        ZooModel<QAInput, String> model = criteria.loadModel();
        try (Predictor<QAInput, String> predictor = model.newPredictor(translator)) {
            return predictor.predict(input);
        }
    }
}
@hakanai

hakanai commented Feb 16, 2023

It seems another resource file is needed:

Exception in thread "main" java.io.FileNotFoundException: Parameter file with prefix: trace_cased_bertqa.pt not found in: ...\path\to\project\data\trace_cased_bertqa.pt or not readable by the engine.
	at ai.djl.mxnet.engine.MxModel.load(MxModel.java:109)
	at ai.djl.repository.zoo.BaseModelLoader.loadModel(BaseModelLoader.java:159)
	at ai.djl.repository.zoo.Criteria.loadModel(Criteria.java:168)

The error makes it sound like maybe there's meant to be another level of directories in here?

@willygoldgewicht

In:

Criteria<QAInput, String> criteria = Criteria.builder()
    .setTypes(QAInput.class, String.class)
    .optModelPath(Paths.get("src/main/resources/trace_cased_bertqa.pt"))
    .optTranslator(translator)
    .optProgress(new ProgressBar()).build();

I get this error:
java.io.FileNotFoundException: trace_cased_bertqa.pt file not found

Why?
Thanks

@KexinFeng
Author

KexinFeng commented Mar 21, 2023

@hakanai

It seems another resource file is needed:

Exception in thread "main" java.io.FileNotFoundException: Parameter file with prefix: trace_cased_bertqa.pt not found in: ...\path\to\project\data\trace_cased_bertqa.pt or not readable by the engine.
	at ai.djl.mxnet.engine.MxModel.load(MxModel.java:109)
	at ai.djl.repository.zoo.BaseModelLoader.loadModel(BaseModelLoader.java:159)
	at ai.djl.repository.zoo.Criteria.loadModel(Criteria.java:168)

The error makes it sound like maybe there's meant to be another level of directories in here?

This shows that you are using the MXNet engine. You need to set the default engine to PyTorch for the example to work. This is done with the VM option: -Dai.djl.default_engine=PyTorch

@KexinFeng
Author

KexinFeng commented Mar 21, 2023

@willygoldgewicht

In:

Criteria<QAInput, String> criteria = Criteria.builder()
    .setTypes(QAInput.class, String.class)
    .optModelPath(Paths.get("src/main/resources/trace_cased_bertqa.pt"))
    .optTranslator(translator)
    .optProgress(new ProgressBar()).build();

I get this error:
java.io.FileNotFoundException: trace_cased_bertqa.pt file not found

Why?
Thanks

As the message indicates, the model file is not found when executing the code. You could try an absolute path in optModelPath().
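When debugging this kind of FileNotFoundException, it can help to print what the relative path actually resolves to, since it depends on the working directory of the launching process (IDE vs. gradle run). A plain-Java sketch, reusing the model file name from this example:

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

public class ModelPathCheck {
    public static void main(String[] args) {
        // Relative paths are resolved against the JVM working directory,
        // which may not be the project root when run from an IDE.
        Path modelPath = Paths.get("src/main/resources/trace_cased_bertqa.pt");
        System.out.println("Resolved to: " + modelPath.toAbsolutePath());
        System.out.println("Exists: " + Files.exists(modelPath));
    }
}
```

If "Exists" prints false, either move the file to the resolved location or pass an absolute path to optModelPath().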

@KexinFeng
Author

KexinFeng commented Mar 21, 2023

To all readers:
Any issue related to this example can also be raised in the DJL GitHub repo, where it may be noticed and answered faster. The corresponding example in DJL is
examples/src/main/java/ai/djl/examples/inference/BertQaInference.java

@muthuishere

I have an issue:

Caused by: java.lang.UnsatisfiedLinkError: /Users/xxx/.djl.ai/pytorch/1.13.1-cpu-osx-x86_64/0.21.0-libdjl_torch.dylib: dlopen(/Users/xxx/.djl.ai/pytorch/1.13.1-cpu-osx-x86_64/0.21.0-libdjl_torch.dylib, 1): Symbol not found: __ZNSt3__113basic_filebufIcNS_11char_traitsIcEEE4openEPKcj

What could be the issue?

@codeoflife

codeoflife commented Jul 19, 2023

The code at https://raw.githubusercontent.com/deepjavalibrary/djl/master/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java worked fine for me.
The framework downloads the vocabulary and PyTorch model automatically and saves them (on Windows) to C:\Users\\.djl.ai\cache\repo\model\nlp\question_answer\ai\djl\pytorch\bertqa\bert\false\SQuAD\0.0.1

Here are the Maven dependencies; I also have the VM argument "-Dai.djl.default_engine=PyTorch"

<!-- https://mvnrepository.com/artifact/ai.djl/api -->
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>api</artifactId>
    <version>0.23.0</version>
</dependency>
<!-- https://mvnrepository.com/artifact/ai.djl.pytorch/pytorch-engine -->
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.23.0</version>
</dependency>
<!-- https://mvnrepository.com/artifact/ai.djl.pytorch/pytorch-model-zoo -->
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-model-zoo</artifactId>
    <version>0.23.0</version>
</dependency>

Note: I get a DLL error when I invoke the same working code from an application that loaded a TensorFlow model before this one. I'm fairly sure it's because I'm using both the DJL and TensorFlow libraries together; I will try to switch to one (probably DJL)

Caused by: java.lang.UnsatisfiedLinkError: C:\Users<user>.djl.ai\pytorch\1.13.1-cpu-win-x86_64\torch_cpu.dll: The operating system cannot run %1
