Skip to content

Instantly share code, notes, and snippets.

@dacr
Last active February 18, 2024 18:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dacr/50bd670c4ec83a17f9169838281e823b to your computer and use it in GitHub Desktop.
Save dacr/50bd670c4ec83a17f9169838281e823b to your computer and use it in GitHub Desktop.
using mistral with DJL / published by https://github.com/dacr/code-examples-manager #389e67ca-de9a-4f47-a1c0-504564fb2dbe/d8a4fd23108a58830d86f5c8bac5315045f00167
// summary : using mistral with DJL
// keywords : djl, machine-learning, llm, mistral, ai, @testable
// publish : gist
// authors : David Crosson
// license : Apache NON-AI License Version 2.0 (https://raw.githubusercontent.com/non-ai-licenses/non-ai-licenses/main/NON-AI-APACHE2)
// id : 389e67ca-de9a-4f47-a1c0-504564fb2dbe
// created-on : 2024-02-03T14:34:48+01:00
// managed-by : https://github.com/dacr/code-examples-manager
// run-with : scala-cli $file
// ---------------------
//> using scala "3.3.1"
//> using dep "org.slf4j:slf4j-api:2.0.12"
//> using dep "org.slf4j:slf4j-simple:2.0.12"
//> using dep "net.java.dev.jna:jna:5.14.0"
//> using dep "ai.djl:api:0.26.0"
//> using dep "ai.djl:basicdataset:0.26.0"
//> using dep "ai.djl.llama:llama:0.26.0"
//> using dep "ai.djl.pytorch:pytorch-engine:0.26.0"
// ---------------------
/* Thanks to Scala.IO, NuMind and, of course, DJL!
https://github.com/numind-tech/scalaio_2024/blob/main/src/main/scala/chatbot/Chatbot.scala
------
djl://ai.djl.huggingface.gguf/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/0.0.1/Q4_K_M, ai.djl.huggingface.gguf/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/0.0.1/Q4_K_M
https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q4_K_M.gguf?download=true
*/
// Set slf4j-simple's default log level to "error" so that lower-severity messages are not printed.
// NOTE(review): presumably this must run before the first logger is created — confirm.
System.setProperty("org.slf4j.simpleLogger.defaultLogLevel", "error")
import ai.djl.repository.zoo.Criteria
import ai.djl.training.util.ProgressBar
import ai.djl.llama.engine.LlamaInput
import ai.djl.llama.engine.LlamaTranslatorFactory
import ai.djl.llama.jni.Token
import ai.djl.llama.jni.TokenIterator
import scala.jdk.CollectionConverters.*
import scala.util.chaining.*
import scala.io.AnsiColor.{BLUE, BOLD, CYAN, GREEN, MAGENTA, RED, RESET, UNDERLINED, YELLOW}
// Label under which the assistant speaks in the conversation transcript.
val name = "LLM"

// Coordinates of the model; they are combined into the djl:// URL given to the criteria below.
val modelId     = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
val quantMethod = "Q4_K_M"
val url         = s"djl://ai.djl.huggingface.gguf/${modelId}/0.0.1/${quantMethod}"

// Describes the model to load: LlamaInput in, TokenIterator out, with a progress bar.
// NOTE(review): "number_gpu_layers" is presumably the number of layers offloaded to the GPU —
// confirm against the DJL llama engine options.
val criteria =
  Criteria
    .builder()
    .setTypes(classOf[LlamaInput], classOf[TokenIterator])
    .optModelUrls(url)
    .optOption("number_gpu_layers", "43")
    .optTranslatorFactory(new LlamaTranslatorFactory())
    .optProgress(new ProgressBar())
    .build()

// Load the model and create the predictor used by every conversation turn.
val model     = criteria.loadModel()
val predictor = model.newPredictor()
// Generation parameters shared by every request.
// NOTE(review): the anti-prompt "User: " presumably makes generation stop when the model
// starts writing the user's next line — confirm against the DJL llama engine documentation.
val param = new LlamaInput.Parameters().tap { p =>
  p.setTemperature(0.7f)
  p.setPenalizeNl(true)
  p.setMirostat(2)
  p.setAntiPrompt(Array("User: "))
}

// Single reusable request object; only its inputs change from one turn to the next.
val in = new LlamaInput().tap(_.setParameters(param))
// Opening text of the conversation: the persona, then the assistant's first line.
val systemPrompt =
  s"""As a computer science teacher, I make my best to help my students to become software experts.
     |
     |$name: How may I help you today ?""".stripMargin

// Mutable transcript of the whole conversation; each turn appends to it.
val prompt = new StringBuilder(systemPrompt)
/** Plays one conversation turn against the model.
  *
  * The user's line is appended to the shared `prompt` (and echoed in blue), the whole
  * accumulated prompt is handed to `predictor`, and the text of every token is printed as
  * the result iterator yields it. The reply is then appended to `prompt` so that the next
  * turn sees it.
  *
  * @param nextInput what the user says in this turn
  * @return the text generated for this turn (all token texts concatenated)
  */
def interact(nextInput: String): String = {
  val userTurn = s"\nUser: $nextInput\n$name: "
  print(s"${BLUE}$userTurn$RESET")
  prompt.append(userTurn)
  in.setInputs(prompt.toString())

  // Echo each token while collecting the full reply.
  val reply = new StringBuilder
  for (token <- predictor.predict(in).asScala) {
    val text = token.getText
    print(text)
    reply.append(text)
  }

  val answer = reply.toString
  prompt.append(answer)
  answer
}
// Scripted three-turn demo conversation. The predictor and the model are closed in a
// finally block so that their resources are released even when one of the turns throws.
try {
  interact("What is a monad ?")
  interact("Could you give me a scala example ?")
  interact("Thank you very much teacher !")
  println()
} finally {
  predictor.close()
  model.close()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment