Skip to content

Instantly share code, notes, and snippets.

@dacr
Last active May 25, 2024 10:19
Show Gist options
  • Save dacr/8ff63bb72a0eac2ad15a3f98e2c94c0a to your computer and use it in GitHub Desktop.
Save dacr/8ff63bb72a0eac2ad15a3f98e2c94c0a to your computer and use it in GitHub Desktop.
using mistral with DJL refactored / published by https://github.com/dacr/code-examples-manager #2c2815f3-f004-4532-b47f-4f664516e0b5/f2c6458c153065abeea53d020e59e98269c2da09
// summary : using mistral with DJL refactored
// keywords : djl, machine-learning, llm, mistral, ai, @testable
// publish : gist
// authors : David Crosson
// license : Apache NON-AI License Version 2.0 (https://raw.githubusercontent.com/non-ai-licenses/non-ai-licenses/main/NON-AI-APACHE2)
// id : 2c2815f3-f004-4532-b47f-4f664516e0b5
// created-on : 2024-02-18T11:13:34+01:00
// managed-by : https://github.com/dacr/code-examples-manager
// run-with : scala-cli $file
// ---------------------
//> using scala "3.4.2"
//> using dep "org.slf4j:slf4j-api:2.0.13"
//> using dep "org.slf4j:slf4j-simple:2.0.13"
//> using dep "ai.djl:api:0.28.0"
//> using dep "ai.djl:basicdataset:0.28.0"
//> using dep "ai.djl.llama:llama:0.28.0"
//> using dep "ai.djl.pytorch:pytorch-engine:0.28.0"
//> using dep "ai.djl.huggingface:tokenizers:0.28.0"
// ---------------------
/* Thanks to Scala.IO, NuMind and, of course, DJL!
https://github.com/numind-tech/scalaio_2024/blob/main/src/main/scala/chatbot/Chatbot.scala
------
djl://ai.djl.huggingface.gguf/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/0.0.1/Q4_K_M, ai.djl.huggingface.gguf/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/0.0.1/Q4_K_M
https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q4_K_M.gguf?download=true
*/
// JVM system properties for SLF4J-simple and DJL, applied in the order listed.
// They are set first, before any DJL object is created below (presumably so the engine
// picks them up when it initialises — confirm against the DJL engine docs).
Seq(
  "org.slf4j.simpleLogger.defaultLogLevel" -> "debug",
  // Alternative PyTorch flavors used on other setups: "cu118", "cpu"
  "PYTORCH_FLAVOR" -> "cu123",
  "ai.djl.pytorch.graph_optimizer" -> "false"
).foreach((key, value) => System.setProperty(key, value))
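/*
On my server: CUDA 12.3 (as reported by nvidia-smi).
When DJL 0.26.0 was current, the latest CUDA version it supported was 11.8; check which
CUDA flavors the DJL version declared above (0.28.0) supports before changing PYTORCH_FLAVOR.
*/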
import ai.djl.repository.zoo.Criteria
import ai.djl.training.util.ProgressBar
import ai.djl.llama.engine.LlamaInput
import ai.djl.llama.engine.LlamaTranslatorFactory
import ai.djl.llama.jni.Token
import ai.djl.llama.jni.TokenIterator
import scala.jdk.CollectionConverters.*
import scala.util.chaining.*
import scala.io.AnsiColor.{BLUE, BOLD, CYAN, GREEN, MAGENTA, RED, RESET, UNDERLINED, YELLOW}
// https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF
/** Speaker label used for the assistant's turns in the chat prompt. */
val name = "LLM"

/** Hugging Face repository holding the GGUF builds of Mistral 7B Instruct v0.2. */
val modelId = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"

/** GGUF quantization variant to load (size 4.37 GB, max RAM 6.87 GB). */
val quantMethod = "Q4_K_M"

/** DJL model-zoo URL assembled from the repository id and the quantization variant. */
val url = s"djl://ai.djl.huggingface.gguf/$modelId/0.0.1/$quantMethod"

// Model lookup criteria: LlamaInput in, TokenIterator out, built by the llama translator.
// "number_gpu_layers" is an engine option (presumably the number of layers offloaded
// to the GPU — confirm against the DJL llama engine docs).
val criteria: Criteria[LlamaInput, TokenIterator] =
  Criteria.builder()
    .setTypes(classOf[LlamaInput], classOf[TokenIterator])
    .optModelUrls(url)
    .optOption("number_gpu_layers", "43")
    .optTranslatorFactory(new LlamaTranslatorFactory())
    .optProgress(new ProgressBar)
    .build()

// Load the model (ProgressBar reports progress), then open a predictor on it.
val model = criteria.loadModel()
val predictor = model.newPredictor()
// Sampling settings shared by every turn: temperature 0.7, newline penalty on,
// mirostat mode 2, and "User: " as the anti-prompt (presumably where generation
// stops and control returns to the user — confirm in the DJL llama docs).
val param = new LlamaInput.Parameters().tap { p =>
  p.setTemperature(0.7f)
  p.setPenalizeNl(true)
  p.setMirostat(2)
  p.setAntiPrompt(Array("User: "))
}

// One input object reused across turns; `interact` replaces its text before each prediction.
val in = new LlamaInput().tap(_.setParameters(param))
/** Runs one chat turn against the shared `predictor`.
  *
  * Appends `"\nUser: <nextInput>\n<name>: "` to the running prompt, sends it through the
  * shared `in` input, and consumes the returned token iterator completely.
  *
  * NOTE(review): if the engine includes the anti-prompt ("User: ") in its output, the
  * next turn will repeat that label in the prompt — confirm against DJL llama behaviour.
  *
  * @param currentPrompt    full conversation text so far
  * @param nextInput        the user's new message
  * @param newResponseToken callback invoked with each generated token's text, in order
  * @return the prompt including the new user message and the complete generated answer
  */
def interact(currentPrompt: String, nextInput: String)(newResponseToken: String => Unit): String = {
  val promptWithQuestion = s"$currentPrompt\nUser: $nextInput\n$name: "
  in.setInputs(promptWithQuestion)

  val generated = predictor.predict(in)
  val answer = new StringBuilder
  while (generated.hasNext) {
    val text = generated.next().getText
    newResponseToken(text)
    answer.append(text)
  }

  promptWithQuestion + answer.result()
}
/** Opening context of the conversation: a framing sentence followed by the assistant's greeting. */
val systemPrompt =
  s"""As a computer science teacher, I make my best to help my students to become software experts.
     |
     |$name: How may I help you today ?""".stripMargin

/** Full transcript after replaying a scripted sequence of user messages.
  *
  * Each question is echoed in yellow and the answer tokens are streamed to stdout as they
  * arrive. Once the conversation ends, or fails, the predictor and then the model are
  * closed so their resources are released.
  */
val finalPrompt =
  try {
    val transcript =
      List(
        "What is a monad ?",
        "Could you give me a scala example ?",
        "Thank you very much teacher !"
      ).foldLeft(systemPrompt) { (currentPrompt, nextInput) =>
        print(s"${YELLOW}$nextInput$RESET")
        interact(currentPrompt, nextInput)(print)
      }
    println() // terminate the last streamed answer's line
    transcript
  } finally {
    // Close the predictor before the model it was created from.
    predictor.close()
    model.close()
  }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment