Skip to content

Instantly share code, notes, and snippets.

@dacr
Last active May 25, 2024 10:20
Show Gist options
  • Save dacr/1dd65c1b86ea1ff64b5eb1b7683c2ea4 to your computer and use it in GitHub Desktop.
Save dacr/1dd65c1b86ea1ff64b5eb1b7683c2ea4 to your computer and use it in GitHub Desktop.
bert question answer using DJL / published by https://github.com/dacr/code-examples-manager #8d02579b-7065-47b1-9d1a-045ed6bd9c05/b63edf7589b5069dcd2299a8c695a32ec1d03d6a
// summary : bert question answer using DJL
// keywords : djl, machine-learning, tutorial, ai, @testable
// publish : gist
// authors : David Crosson
// license : Apache NON-AI License Version 2.0 (https://raw.githubusercontent.com/non-ai-licenses/non-ai-licenses/main/NON-AI-APACHE2)
// id : 8d02579b-7065-47b1-9d1a-045ed6bd9c05
// created-on : 2024-01-29T11:58:43+01:00
// managed-by : https://github.com/dacr/code-examples-manager
// run-with : scala-cli $file
// ---------------------
//> using scala "3.4.2"
//> using dep "org.slf4j:slf4j-api:2.0.13"
//> using dep "org.slf4j:slf4j-simple:2.0.13"
//> using dep "net.java.dev.jna:jna:5.14.0"
//> using dep "ai.djl:api:0.28.0"
//> using dep "ai.djl:basicdataset:0.28.0"
//> using dep "ai.djl:model-zoo:0.28.0"
//> using dep "ai.djl.huggingface:tokenizers:0.28.0"
//> using dep "ai.djl.mxnet:mxnet-engine:0.28.0"
//> using dep "ai.djl.mxnet:mxnet-model-zoo:0.28.0"
//> using dep "ai.djl.pytorch:pytorch-engine:0.28.0"
//> using dep "ai.djl.pytorch:pytorch-model-zoo:0.28.0"
//> using dep "ai.djl.tensorflow:tensorflow-engine:0.28.0"
//> using dep "ai.djl.tensorflow:tensorflow-model-zoo:0.28.0"
////> using dep "ai.djl.paddlepaddle:paddlepaddle-engine:0.28.0"
////> using dep "ai.djl.paddlepaddle:paddlepaddle-model-zoo:0.28.0"
//> using dep "ai.djl.onnxruntime:onnxruntime-engine:0.28.0"
// ---------------------
// Quiet slf4j-simple down to errors only. Kept as the very first statement so the
// property is already set when DJL obtains its loggers.
System.setProperty("org.slf4j.simpleLogger.defaultLogLevel", "error")
import ai.djl.Application
import ai.djl.engine.Engine
import ai.djl.modality.Classifications
import ai.djl.repository.zoo.Criteria
import ai.djl.training.util.ProgressBar
import ai.djl.huggingface.translator.{TextClassificationTranslatorFactory, TextEmbeddingTranslatorFactory}
import ai.djl.modality.nlp.qa.QAInput
import scala.io.AnsiColor.{BLUE, BOLD, CYAN, GREEN, MAGENTA, RED, RESET, UNDERLINED, YELLOW}
import scala.util.Using

/** Model-zoo lookup: a BERT-backbone model mapping a (question, passage) QAInput to an answer String. */
val criteria =
  Criteria.builder
    .setTypes(classOf[QAInput], classOf[String])
    // Restrict the search to question-answering models. The "backbone" filter alone does not
    // say which task a matching model was built for, and several engine zoos are on the classpath.
    .optApplication(Application.NLP.QUESTION_ANSWER)
    .optFilter("backbone", "bert")
    .optProgress(new ProgressBar)
    .build

val question1 = "When did BBC Japan start broadcasting?"
val question2 = "When did BBC Japan end broadcasting?"
val question3 = "What is BBC Japan?"

/** Context passage the model extracts answers from (identical text for every question). */
val paragraph =
  """
    |BBC Japan was a general entertainment Channel.
    |Which operated between December 2004 and April 2006.
    |It ceased operations after its Japanese distributor folded.
    |""".stripMargin

// Both the loaded model and its predictor are AutoCloseable; scope them so they are
// released once every question has been answered, even if a prediction throws.
Using.resource(criteria.loadModel()) { model =>
  Using.resource(model.newPredictor()) { predictor =>
    for question <- List(question1, question2, question3) do
      println(s"$question ${predictor.predict(QAInput(question, paragraph))}")
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment