@dacr
Created January 29, 2024 11:07
BERT question answering using DJL / published by https://github.com/dacr/code-examples-manager #8d02579b-7065-47b1-9d1a-045ed6bd9c05/9ccacc46529df169660889fdd29d48110e74be76
// summary : BERT question answering using DJL
// keywords : djl, machine-learning, tutorial, ai, @testable
// publish : gist
// authors : David Crosson
// license : Apache NON-AI License Version 2.0 (https://raw.githubusercontent.com/non-ai-licenses/non-ai-licenses/main/NON-AI-APACHE2)
// id : 8d02579b-7065-47b1-9d1a-045ed6bd9c05
// created-on : 2024-01-29T11:58:43+01:00
// managed-by : https://github.com/dacr/code-examples-manager
// run-with : scala-cli $file
// ---------------------
//> using scala "3.3.1"
//> using dep "org.slf4j:slf4j-api:2.0.11"
//> using dep "org.slf4j:slf4j-simple:2.0.11"
//> using dep "net.java.dev.jna:jna:5.14.0"
//> using dep "ai.djl:api:0.26.0"
//> using dep "ai.djl:basicdataset:0.26.0"
//> using dep "ai.djl:model-zoo:0.26.0"
//> using dep "ai.djl.huggingface:tokenizers:0.26.0"
//> using dep "ai.djl.mxnet:mxnet-engine:0.26.0"
//> using dep "ai.djl.mxnet:mxnet-model-zoo:0.26.0"
//> using dep "ai.djl.pytorch:pytorch-engine:0.26.0"
//> using dep "ai.djl.pytorch:pytorch-model-zoo:0.26.0"
//> using dep "ai.djl.tensorflow:tensorflow-engine:0.26.0"
//> using dep "ai.djl.tensorflow:tensorflow-model-zoo:0.26.0"
//> using dep "ai.djl.paddlepaddle:paddlepaddle-engine:0.26.0"
//> using dep "ai.djl.paddlepaddle:paddlepaddle-model-zoo:0.26.0"
//> using dep "ai.djl.onnxruntime:onnxruntime-engine:0.26.0"
// ---------------------
// Silence DJL's info-level logging (must be set before the first slf4j logger is created).
System.setProperty("org.slf4j.simpleLogger.defaultLogLevel", "error")
import ai.djl.repository.zoo.Criteria
import ai.djl.training.util.ProgressBar
import ai.djl.modality.nlp.qa.QAInput
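// Look up a pretrained BERT question-answering model in the DJL model zoos:
// the input is a QAInput (question + context paragraph), the output is the answer as a String.
val criteria =
  Criteria.builder
    .setTypes(classOf[QAInput], classOf[String])
    .optFilter("backbone", "bert")
    .optProgress(new ProgressBar)
    .build
val model     = criteria.loadModel()
val predictor = model.newPredictor()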
val question1 = "When did BBC Japan start broadcasting?"
val question2 = "When did BBC Japan end broadcasting?"
val question3 = "What is BBC Japan?"
val paragraph =
  """
    |BBC Japan was a general entertainment Channel.
    |Which operated between December 2004 and April 2006.
    |It ceased operations after its Japanese distributor folded.
    |""".stripMargin
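println(s"$question1 ${predictor.predict(QAInput(question1, paragraph))}")
println(s"$question2 ${predictor.predict(QAInput(question2, paragraph))}")
println(s"$question3 ${predictor.predict(QAInput(question3, paragraph))}")
// The predictor and the model hold native engine resources: release them explicitly.
predictor.close()
model.close()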
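For readers wondering what the returned answer string corresponds to: an extractive BERT QA model such as the one loaded by the script typically scores every context token as a possible answer start and as a possible answer end, and the answer is the highest-scoring valid span. The plain-Scala sketch below illustrates that decoding step only. The `SpanDecoding` object, `bestSpan`, `maxAnswerLength`, and all scores are hypothetical illustrations, not DJL APIs, and DJL's own translator may decode the model output differently.

```scala
// Hypothetical illustration of extractive-QA span decoding (not a DJL API).
// The score arrays stand in for the per-token start/end scores a BERT QA head produces.
object SpanDecoding:

  /** Returns the (start, end) token indices maximizing startScores(start) + endScores(end),
    * subject to start <= end and end - start + 1 <= maxAnswerLength.
    * Brute force over all valid spans; ties resolve to the first span found.
    */
  def bestSpan(startScores: Array[Double], endScores: Array[Double], maxAnswerLength: Int): (Int, Int) =
    require(startScores.length == endScores.length, "score arrays must have the same length")
    require(startScores.nonEmpty, "score arrays must not be empty")
    require(maxAnswerLength >= 1, "maxAnswerLength must be at least 1")
    val candidates =
      for
        s <- startScores.indices
        e <- s until math.min(s + maxAnswerLength, endScores.length)
      yield (s, e)
    candidates.maxBy { case (s, e) => startScores(s) + endScores(e) }

  def main(args: Array[String]): Unit =
    // Hypothetical tokens and scores, loosely modeled on the script's paragraph.
    val tokens      = Array("which", "operated", "between", "december", "2004", "and", "april", "2006")
    val startScores = Array(0.1, 0.2, 0.5, 3.0, 1.0, 0.1, 0.8, 0.2)
    val endScores   = Array(0.1, 0.1, 0.2, 0.5, 2.5, 0.1, 0.4, 1.5)
    val (s, e)      = bestSpan(startScores, endScores, maxAnswerLength = 5)
    println(tokens.slice(s, e + 1).mkString(" ")) // prints "december 2004"
```

The brute-force scan costs O(n × maxAnswerLength), which is negligible for a few hundred tokens; the length cap keeps the decoder from selecting implausibly long answers.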