Skip to content

Instantly share code, notes, and snippets.

@hotip
Created August 18, 2023 13:06
Show Gist options
  • Save hotip/0514521ffee37c3b49cef52dbad96f16 to your computer and use it in GitHub Desktop.
Save hotip/0514521ffee37c3b49cef52dbad96f16 to your computer and use it in GitHub Desktop.
XunFeiXingHuo-LLM
import io.ktor.client.*
import io.ktor.client.engine.cio.*
import io.ktor.client.plugins.websocket.*
import io.ktor.client.request.*
import io.ktor.util.*
import io.ktor.utils.io.core.*
import io.ktor.websocket.*
import korlibs.time.DateFormat
import korlibs.time.DateTime
import korlibs.time.TimeSpan
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.ProducerScope
import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.callbackFlow
import kotlinx.serialization.json.*
import kotlinx.uuid.nextUUID
import org.kotlincrypto.macs.hmac.sha2.HmacSHA256
import kotlin.random.Random
/**
 * [LLMProvider] for the Xunfei Spark chat API at `spark-api.xf-yun.com`, spoken over WebSocket.
 *
 * Each [textComplete] call opens a freshly signed WebSocket connection, sends the whole chat
 * history, and publishes the growing session on [response] as reply chunks stream in.
 * Instances are created through [v1] / [v2].
 */
class Xunfei private constructor(
    authConfig: AuthConfig,
    private val coroutineScope: CoroutineScope,
    private val version: Version = Version.V2
) : LLMProvider<ChatMessage>() {
    /** Request path and `domain` request parameter for one API version. */
    data class VersionConfig(val path: String, val domain: String)

    /** Credentials used to sign every connection and to fill the request header. */
    data class AuthConfig(
        val appId: String,
        val apiKey: String,
        val apiSecret: String,
    )

    companion object {
        val configs = mapOf(
            Version.V1 to VersionConfig("/v1.1/chat", "general"),
            Version.V2 to VersionConfig("/v2.1/chat", "generalv2"),
        )

        enum class Version { V1, V2 }

        /**
         * 1.5 API
         */
        fun v1(authConfig: AuthConfig, coroutineScope: CoroutineScope) = Xunfei(authConfig, coroutineScope, Version.V1)

        /**
         * 2.0 API
         */
        fun v2(authConfig: AuthConfig, coroutineScope: CoroutineScope) = Xunfei(authConfig, coroutineScope)
    }

    private fun getPath() = configs[version]?.path ?: error("no such version")
    private fun getDomain() = configs[version]?.domain ?: error("no such version")

    /**
     * Current timestamp formatted as `EEE, dd MMM yyyy HH:mm:ss z`, used both in the signature and
     * as the `date` query parameter. Every read produces a new value, so read it once per request.
     *
     * NOTE(review): the original author remarked that this Chinese API expects US-style
     * (Locale.US) dates. The fixed -8h shift presumably converts UTC+8 local time to GMT and is
     * only right on devices running at UTC+8. Confirm before relying on it elsewhere.
     */
    private val date: String
        get() {
            val format = DateFormat("EEE, dd MMM yyyy HH:mm:ss z")
            val date = DateTime.nowLocal().addOffset(TimeSpan(-8 * 60 * 60 * 1000.0))
            return date.format(format)
        }

    private val xfHost = "spark-api.xf-yun.com"
    private val path = getPath()
    private val appId = authConfig.appId
    private val apiSecret = authConfig.apiSecret
    private val apiKey = authConfig.apiKey

    /**
     * Builds the base64 `authorization` query parameter.
     *
     * Signs the `host`, `date` and request-line block with HMAC-SHA256 keyed by [apiSecret]. The
     * result is embedded in the `api_key=…, algorithm=…, headers=…, signature=…` string, which is
     * then base64-encoded. [date] must be exactly the value sent as the `date` query parameter.
     */
    private fun generateAuthorization(
        apiKey: String,
        apiSecret: String,
        host: String,
        date: String,
        path: String
    ): String {
        val requestLine = "GET $path HTTP/1.1"
        val header = """
            |host: $host
            |date: $date
            |$requestLine
        """.trimMargin()
        val hmacSha256 = HmacSHA256(apiSecret.toByteArray())
        val signature = hmacSha256.doFinal(header.toByteArray()).encodeBase64()
        val authorizationOrigin =
            """api_key="$apiKey", algorithm="hmac-sha256", headers="host date request-line", signature="$signature""""
        // The signature is a credential-derived value; do not log it.
        return authorizationOrigin.toByteArray().encodeBase64()
    }

    /** Opens the signed chat WebSocket on this client. */
    private suspend fun HttpClient.openChatSocket(): DefaultClientWebSocketSession {
        // Read the timestamp once: it is both signed and sent, and the two must match exactly.
        val requestDate = date
        return webSocketSession {
            url("wss://$xfHost$path")
            parameter(
                "authorization", generateAuthorization(
                    apiKey = apiKey,
                    apiSecret = apiSecret,
                    host = xfHost,
                    date = requestDate,
                    path = path
                )
            )
            parameter("date", requestDate)
            parameter("host", xfHost)
        }
    }

    private val uid: String get() = Random.nextUUID().toString().take(32)

    /** The request currently streaming, if any. */
    private var sendingJob: Job? = null

    /**
     * Sends the whole [session] history and streams the reply.
     *
     * Any request still in flight is cancelled first. Each text frame is merged into the session
     * with [responseToChatSession] and published on [response]. If the connection fails, or
     * closes before the final chunk, an assistant error message with status
     * [MessageStatus.END] is published, so the client is never left waiting for an END that
     * will not arrive.
     */
    override fun textComplete(session: ChatSession<ChatMessage>) {
        sendingJob?.cancel()
        // Launch as a child of the injected scope so that cancelling the scope also stops the request.
        sendingJob = coroutineScope.launch(Dispatchers.IO) {
            val client = HttpClient(CIO) {
                install(WebSockets)
            }
            var current = session
            var finished = false
            try {
                val webSocket = client.openChatSocket()
                webSocket.send(buildRequest(session))
                // Iterating the channel terminates when the socket closes. Polling
                // receiveCatching().getOrNull() would instead spin forever on a closed channel.
                for (frame in webSocket.incoming) {
                    if (frame !is Frame.Text) continue
                    current = responseToChatSession(current, frame.readText())
                    publish(current)
                    if (current.status == MessageStatus.END) {
                        finished = true
                        break
                    }
                }
                if (!finished) {
                    publish(current.endWithError("connection closed before the response completed"))
                }
            } catch (e: CancellationException) {
                // Superseded by a newer request, or the owning scope was cancelled: stay silent.
                throw e
            } catch (e: Exception) {
                // Surface network/protocol failures as a terminal message instead of crashing the scope.
                publish(current.endWithError(e.toString()))
            } finally {
                // The client is per-request; release it (and the connection it owns) every time.
                client.close()
            }
        }
    }

    /**
     * Serializes [session] into the chat request body.
     *
     * Shape (from the API sample):
     * ```
     * {
     *   "header":    { "app_id": "12345", "uid": "12345" },
     *   "parameter": { "chat": { "domain": "general", "temperature": 0.5, "max_tokens": 1024 } },
     *   "payload":   { "message": { "text": [
     *       { "role": "user", "content": "Who are you?" },        // earlier user question
     *       { "role": "assistant", "content": "....." },          // earlier AI answer
     *       { "role": "user", "content": "What can you do?" }     // latest question
     *   ] } }
     * }
     * ```
     * To get context-aware answers, the full history must be sent every time. If no context is
     * needed, sending only the latest question is enough. The tokens of all `content` fields
     * together must stay within 8192; this method does not trim the history.
     */
    private fun buildRequest(session: ChatSession<ChatMessage>): String = buildJsonObject {
        put("header", buildJsonObject {
            put("app_id", appId)
            put("uid", uid)
        })
        put("parameter", buildJsonObject {
            put("chat", buildJsonObject {
                put("domain", getDomain())
                put("temperature", 0.5)
                put("max_tokens", 1024)
            })
        })
        put("payload", buildJsonObject {
            put("message", buildJsonObject {
                put("text", JsonArray(
                    session.chatHistory.map { msg ->
                        buildJsonObject {
                            put("role", msg.role.toApiRole())
                            put("content", msg.text)
                        }
                    }
                ))
            })
        })
    }.toString()

    /**
     * Maps a [Role2] to the role string used in the request.
     * The request sample only shows "user" and "assistant", so [Role2.BOT] is sent as "assistant".
     */
    private fun Role2.toApiRole(): String = when (this) {
        Role2.USER -> "user"
        Role2.BOT, Role2.Assistant -> "assistant"
    }

    private lateinit var _producer: ProducerScope<ChatSession<ChatMessage>>

    override val response: Flow<ChatSession<ChatMessage>> =
        callbackFlow {
            // Only the most recent collector receives updates.
            _producer = this
            awaitClose()
        }

    /** Sends [session] to the current collector; a no-op if nobody has collected [response] yet. */
    private fun publish(session: ChatSession<ChatMessage>) {
        if (::_producer.isInitialized) {
            _producer.trySend(session)
        }
    }

    /** Appends an assistant error message and marks the session as finished. */
    private fun ChatSession<ChatMessage>.endWithError(reason: String): ChatSession<ChatMessage> =
        copy(
            chatHistory = chatHistory + ChatMessage(
                text = "error: \n$reason",
                role = Role2.Assistant,
                status = MessageStatus.END,
            ),
            status = MessageStatus.END,
        )

    override fun imageGenerate(prompt: String): Flow<ChatSession<ChatMessage>> {
        TODO()
    }

    /**
     * Merges one streamed response chunk into [sendingSession].
     *
     * If the last message is an unfinished message from the same role as the chunk (an assistant
     * reply still streaming), the chunk's text is appended to it. Otherwise the chunk becomes a new
     * message. Either way the session status becomes the chunk's status. An empty history is
     * handled by appending.
     */
    override fun responseToChatSession(
        sendingSession: ChatSession<ChatMessage>,
        responseBody: String
    ): ChatSession<ChatMessage> {
        val newMessage: ChatMessage = responseBody.toMessage()
        val history = sendingSession.chatHistory
        val lastMessage = history.lastOrNull()
        val newHistory = if (
            lastMessage != null &&
            lastMessage.status != MessageStatus.END &&
            // Never merge a reply into the user's own (possibly non-END) question.
            lastMessage.role == newMessage.role
        ) {
            history.dropLast(1) + ChatMessage(
                text = lastMessage.text + newMessage.text,
                status = newMessage.status,
                role = lastMessage.role
            )
        } else {
            history + newMessage
        }
        return sendingSession.copy(chatHistory = newHistory, status = newMessage.status)
    }

    /**
     * Parses one response frame into an assistant [ChatMessage].
     *
     * Expected shape (from the API sample):
     * ```
     * {
     *   "header":  { "code": 0, "message": "Success", "sid": "...", "status": 2 },
     *   "payload": {
     *     "choices": { "status": 2, "seq": 0, "text": [ { "content": "...", "role": "assistant", "index": 0 } ] },
     *     "usage":   { "text": { "question_tokens": 4, "prompt_tokens": 5, "completion_tokens": 9, "total_tokens": 14 } }
     *   }
     * }
     * ```
     * `choices.status` 0 maps to [MessageStatus.BEGIN], 1 to [MessageStatus.CONTENT], and anything
     * else to [MessageStatus.END]. Any frame that does not match is turned into an END error
     * message containing the raw frame, e.g. an error response without a payload.
     */
    private fun String.toMessage(): ChatMessage {
        return runCatching {
            val root = Json.parseToJsonElement(this).jsonObject
            val choices = root["payload"]?.jsonObject?.get("choices")?.jsonObject
                ?: error("response has no payload.choices")
            val statusCode: Int = choices["status"]?.jsonPrimitive?.int
                ?: error("response has no payload.choices.status")
            val text: String = choices["text"]?.jsonArray
                ?.joinToString(separator = "") { entry ->
                    entry.jsonObject["content"]?.jsonPrimitive?.content.orEmpty()
                }
                ?: error("response has no payload.choices.text")
            val status: MessageStatus = when (statusCode) {
                0 -> MessageStatus.BEGIN
                1 -> MessageStatus.CONTENT
                else -> MessageStatus.END
            }
            ChatMessage(text = text, role = Role2.Assistant, status = status)
        }.getOrElse { err ->
            ChatMessage(
                text = "error: \nmsg: $this \nerror:$err",
                role = Role2.Assistant,
                status = MessageStatus.END
            )
        }
    }
}
/**
 * Common base for chat back ends.
 *
 * A provider exposes only session-free operations. It owns the network connection, while the
 * chat session itself is kept and passed in by the client. Implementations must be able to turn
 * a raw response body into an updated [ChatSession].
 *
 * @param M message type this provider emits
 */
abstract class LLMProvider<M : ChatMessage> {

    /** Requests a text completion for the given conversation. Nothing is returned directly. */
    abstract fun textComplete(session: ChatSession<ChatMessage>)

    /** Requests image generation for [prompt]. */
    abstract fun imageGenerate(prompt: String): Flow<ChatSession<M>>

    /** Sessions emitted by this provider. */
    abstract val response: Flow<ChatSession<M>>

    /** Folds one raw [responseBody] into [sendingSession] and returns the resulting session. */
    protected abstract fun responseToChatSession(
        sendingSession: ChatSession<M>,
        responseBody: String,
    ): ChatSession<M>
}
/** Author of a [ChatMessage]. */
enum class Role2 {
    USER,
    BOT,
    Assistant,
}
/**
 * State of a message returned by the backend.
 */
enum class MessageStatus {
    /** The reply has just started. */
    BEGIN,

    /** Intermediate content of the reply. */
    CONTENT,

    /** The reply is complete. */
    END,
}
/**
 * A single chat message.
 *
 * @param text message body
 * @param role author of the message
 * @param status streaming state of the message; defaults to [MessageStatus.CONTENT]
 */
open class ChatMessage(
    text: String,
    role: Role2,
    status: MessageStatus = MessageStatus.CONTENT,
) {
    val text: String = text
    val role: Role2 = role
    val status: MessageStatus = status
}
/**
 * Snapshot of a conversation.
 *
 * @param T message type stored in [chatHistory]
 * @property conversionName name of the conversation (presumably "conversation name"; kept as-is)
 * @property chatHistory messages of the conversation, oldest first
 * @property status current session state; the next message may only be sent while it is
 *   [MessageStatus.END]
 */
data class ChatSession<out T : ChatMessage>(
    val conversionName: String,
    val chatHistory: List<T>,
    var status: MessageStatus = MessageStatus.END,
)
/** Whether the next message may be sent, i.e. whether [ChatSession.status] is [MessageStatus.END]. */
fun <T : ChatMessage> ChatSession<T>.canSendNextMessage(): Boolean {
    return status == MessageStatus.END
}
/**
 * Returns a copy of this session with [newMessage] added at the end of the history.
 * The status is carried over unchanged, and the receiver itself is not modified.
 */
fun <T : ChatMessage> ChatSession<T>.appendMessage(newMessage: T): ChatSession<T> {
    val extendedHistory: List<T> = chatHistory + newMessage
    return copy(chatHistory = extendedHistory)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment