Created
August 18, 2023 13:06
-
-
Save hotip/0514521ffee37c3b49cef52dbad96f16 to your computer and use it in GitHub Desktop.
XunFeiXingHuo-LLM
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import io.ktor.client.*
import io.ktor.client.engine.cio.*
import io.ktor.client.plugins.websocket.*
import io.ktor.client.request.*
import io.ktor.util.*
import io.ktor.utils.io.core.*
import io.ktor.websocket.*
import korlibs.time.DateFormat
import korlibs.time.DateTime
import korlibs.time.TimeSpan
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ProducerScope
import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.callbackFlow
import kotlinx.coroutines.flow.receiveAsFlow
import kotlinx.serialization.json.*
import kotlinx.uuid.nextUUID
import org.kotlincrypto.macs.hmac.sha2.HmacSHA256
import kotlin.random.Random
/**
 * [LLMProvider] for the iFlytek (讯飞) Spark ("星火") chat API, spoken over a signed WebSocket.
 *
 * Each [textComplete] call opens a new connection and sends the complete chat history. For every
 * streamed chunk it pushes a cumulative [ChatSession] snapshot into [response], stopping once a
 * chunk with [MessageStatus.END] arrives. Starting a new request cancels the one still in flight.
 *
 * Build instances with [v1] or [v2].
 *
 * @param authConfig app id / API key / API secret used to sign the connection URL.
 * @param coroutineScope scope that owns request coroutines; cancelling it aborts in-flight requests.
 * @param version API version; selects the endpoint path and the `domain` request parameter.
 */
class Xunfei private constructor(
    authConfig: AuthConfig,
    private val coroutineScope: CoroutineScope,
    private val version: Version = Version.V2,
) : LLMProvider<ChatMessage>() {

    /** Endpoint path and `parameter.chat.domain` value for one API version. */
    data class VersionConfig(val path: String, val domain: String)

    /** Credentials for the Spark API. */
    data class AuthConfig(
        val appId: String,
        val apiKey: String,
        val apiSecret: String,
    )

    private val versionConfig: VersionConfig = configs[version] ?: error("no such version")
    private val endpointPath = versionConfig.path
    private val appId = authConfig.appId
    private val apiSecret = authConfig.apiSecret
    private val apiKey = authConfig.apiKey

    /** Job of the request currently streaming, if any; replaced on every [textComplete]. */
    private var sendingJob: Job? = null

    /**
     * Hand-off between request coroutines and collectors of [response].
     *
     * Unlimited capacity means snapshots produced before anyone collects are buffered instead of
     * crashing, and the final END snapshot can never be dropped by a full buffer.
     */
    private val sessions = Channel<ChatSession<ChatMessage>>(Channel.UNLIMITED)

    /**
     * Session snapshots produced by [textComplete].
     *
     * Each snapshot is delivered to exactly one collector, so collect this flow once.
     */
    override val response: Flow<ChatSession<ChatMessage>> = sessions.receiveAsFlow()

    /** Random header uid: a UUID string cut to 32 characters (presumably the API's uid length limit — verify). */
    private val uid: String get() = Random.nextUUID().toString().take(32)

    override fun textComplete(session: ChatSession<ChatMessage>) {
        sendingJob?.cancel()
        // Launch as a child of coroutineScope (not of a detached Job) so cancelling the scope stops it.
        sendingJob = coroutineScope.launch(Dispatchers.IO) { streamCompletion(session) }
    }

    override fun imageGenerate(prompt: String): Flow<ChatSession<ChatMessage>> {
        TODO()
    }

    /**
     * Runs one request end to end and guarantees that a snapshot with [MessageStatus.END] is
     * eventually emitted. An error message is appended if the socket closes early or the
     * connection fails. Both the WebSocket session and the HTTP client are always closed.
     */
    private suspend fun streamCompletion(session: ChatSession<ChatMessage>) {
        var current = session
        var finished = false
        val client = HttpClient(CIO) { install(WebSockets) }
        try {
            // `webSocket { }` closes the session when the block returns or throws.
            client.webSocket(request = { authorize() }) {
                send(buildRequestBody(session).toString())
                for (frame in incoming) {
                    if (frame !is Frame.Text) continue
                    current = responseToChatSession(current, frame.readText())
                    sessions.trySend(current)
                    finished = current.status == MessageStatus.END
                    if (finished) break
                }
            }
            if (!finished) {
                sessions.trySend(finishWithError(current, "connection closed before the answer was complete"))
            }
        } catch (e: CancellationException) {
            throw e
        } catch (e: Exception) {
            // Report instead of propagating, so a network failure cannot cancel the caller's scope.
            if (!finished) sessions.trySend(finishWithError(current, "${e::class.simpleName}: ${e.message}"))
        } finally {
            client.close()
        }
    }

    /** Appends an assistant error message to [session] and marks it [MessageStatus.END]. */
    private fun finishWithError(session: ChatSession<ChatMessage>, reason: String): ChatSession<ChatMessage> =
        session.copy(
            chatHistory = session.chatHistory +
                ChatMessage(text = "error: $reason", role = Role2.Assistant, status = MessageStatus.END),
            status = MessageStatus.END,
        )

    /** Points the request at the versioned endpoint and adds the signed query parameters. */
    private fun HttpRequestBuilder.authorize() {
        // Sample the clock once: the signature covers the date, so the signed date and the
        // `date` query parameter must be the very same string.
        val date = currentDate()
        url("wss://$XF_HOST$endpointPath")
        parameter(
            "authorization",
            generateAuthorization(
                apiKey = apiKey,
                apiSecret = apiSecret,
                host = XF_HOST,
                date = date,
                path = endpointPath,
            ),
        )
        parameter("date", date)
        parameter("host", XF_HOST)
    }

    /**
     * Returns the current time formatted with [DATE_FORMAT] at UTC offset 0.
     *
     * The API expects an English (US-locale style) date even though it is a Chinese service.
     * `DateTime.now()` is already UTC, so the result does not depend on the device's time zone.
     */
    private fun currentDate(): String = DateTime.now().format(DATE_FORMAT)

    /**
     * Builds the `authorization` query parameter.
     *
     * The string `host: …\ndate: …\nGET <path> HTTP/1.1` is signed with HMAC-SHA256 keyed by
     * [apiSecret], and the signature is Base64-encoded. It is then wrapped together with
     * [apiKey] in an `api_key=…, signature=…` line, and that whole line is Base64-encoded again.
     */
    private fun generateAuthorization(
        apiKey: String,
        apiSecret: String,
        host: String,
        date: String,
        path: String,
    ): String {
        val signatureOrigin = listOf(
            "host: $host",
            "date: $date",
            "GET $path HTTP/1.1",
        ).joinToString("\n")
        val signature = HmacSHA256(apiSecret.toByteArray())
            .doFinal(signatureOrigin.toByteArray())
            .encodeBase64()
        val authorizationOrigin =
            """api_key="$apiKey", algorithm="hmac-sha256", headers="host date request-line", signature="$signature""""
        return authorizationOrigin.toByteArray().encodeBase64()
    }

    /**
     * Builds the request JSON. Example shape:
     * ```
     * {"header": {"app_id": "12345", "uid": "12345"},
     *  "parameter": {"chat": {"domain": "general", "temperature": 0.5, "max_tokens": 1024}},
     *  "payload": {"message": {"text": [
     *      {"role": "user", "content": "你是谁"},            // earlier question
     *      {"role": "assistant", "content": "....."},       // earlier answer
     *      {"role": "user", "content": "你会做什么"}          // newest question
     *  ]}}}
     * ```
     * To get context-aware answers, send the whole history each time. The tokens of all `content`
     * fields combined must stay within 8192, so long conversations need their history trimmed.
     */
    private fun buildRequestBody(session: ChatSession<ChatMessage>): JsonObject = buildJsonObject {
        putJsonObject("header") {
            put("app_id", appId)
            put("uid", uid)
        }
        putJsonObject("parameter") {
            putJsonObject("chat") {
                put("domain", versionConfig.domain)
                put("temperature", TEMPERATURE)
                put("max_tokens", MAX_TOKENS)
            }
        }
        putJsonObject("payload") {
            putJsonObject("message") {
                putJsonArray("text") {
                    session.chatHistory.forEach { msg ->
                        addJsonObject {
                            put("role", msg.role.toSparkRole())
                            put("content", msg.text)
                        }
                    }
                }
            }
        }
    }

    /** Maps a [Role2] to the role names used in the request sample: `user` or `assistant`. */
    private fun Role2.toSparkRole(): String = when (this) {
        Role2.USER -> "user"
        Role2.BOT, Role2.Assistant -> "assistant"
    }

    /**
     * Merges one response frame into [sendingSession].
     *
     * If the last message is an unfinished, non-user message, the chunk's text is appended to it.
     * Otherwise (empty history, a user message, or a finished reply) the chunk becomes a new
     * message. The session status becomes the chunk's status.
     */
    override fun responseToChatSession(
        sendingSession: ChatSession<ChatMessage>,
        responseBody: String,
    ): ChatSession<ChatMessage> {
        val chunk = responseBody.toMessage()
        val history = sendingSession.chatHistory
        val last = history.lastOrNull()
        val newHistory = if (last != null && last.role != Role2.USER && last.status != MessageStatus.END) {
            history.dropLast(1) + ChatMessage(text = last.text + chunk.text, role = last.role, status = chunk.status)
        } else {
            history + chunk
        }
        return sendingSession.copy(chatHistory = newHistory, status = chunk.status)
    }

    /**
     * Parses one streamed frame into an assistant [ChatMessage]. Example frame:
     * ```
     * {"header": {"code": 0, "message": "Success", "sid": "cht000cb087@dx18793cd421fb894542", "status": 2},
     *  "payload": {"choices": {"status": 2, "seq": 0,
     *                          "text": [{"content": "我可以帮助你的吗?", "role": "assistant", "index": 0}]},
     *              "usage": {"text": {"question_tokens": 4, "prompt_tokens": 5,
     *                                 "completion_tokens": 9, "total_tokens": 14}}}}
     * ```
     * `choices.status` maps 0 to [MessageStatus.BEGIN], 1 to [MessageStatus.CONTENT], and anything
     * else to [MessageStatus.END]. The contents of all `text` entries are concatenated.
     * A non-zero `header.code` or a malformed frame yields an error message with status END.
     */
    private fun String.toMessage(): ChatMessage = runCatching {
        val root = Json.parseToJsonElement(this).jsonObject
        val header = root["header"]?.jsonObject
        val code = header?.get("code")?.jsonPrimitive?.intOrNull ?: 0
        check(code == 0) { "server error $code: ${header?.get("message")?.jsonPrimitive?.contentOrNull}" }
        val choices = checkNotNull(root["payload"]?.jsonObject?.get("choices")?.jsonObject) { "missing payload.choices" }
        val statusCode = checkNotNull(choices["status"]?.jsonPrimitive?.intOrNull) { "missing choices.status" }
        val parts = checkNotNull(choices["text"]?.jsonArray) { "missing choices.text" }
        val text = parts.joinToString(separator = "") { part ->
            part.jsonObject["content"]?.jsonPrimitive?.content.orEmpty()
        }
        val status = when (statusCode) {
            0 -> MessageStatus.BEGIN
            1 -> MessageStatus.CONTENT
            else -> MessageStatus.END
        }
        ChatMessage(text = text, role = Role2.Assistant, status = status)
    }.getOrElse { err ->
        ChatMessage(
            text = "error: \nmsg: $this \nerror:${err.cause} ${err.message}",
            role = Role2.Assistant,
            status = MessageStatus.END,
        )
    }

    companion object {
        val configs = mapOf(
            Version.V1 to VersionConfig("/v1.1/chat", "general"),
            Version.V2 to VersionConfig("/v2.1/chat", "generalv2"),
        )

        enum class Version { V1, V2 }

        private const val XF_HOST = "spark-api.xf-yun.com"
        private const val TEMPERATURE = 0.5
        private const val MAX_TOKENS = 1024
        private val DATE_FORMAT = DateFormat("EEE, dd MMM yyyy HH:mm:ss z")

        /** Spark 1.5 API. */
        fun v1(authConfig: AuthConfig, coroutineScope: CoroutineScope) = Xunfei(authConfig, coroutineScope, Version.V1)

        /** Spark 2.0 API. */
        fun v2(authConfig: AuthConfig, coroutineScope: CoroutineScope) = Xunfei(authConfig, coroutineScope)
    }
}
/**
 * Session-free contract for a large-language-model backend.
 *
 * The provider owns the network connection; the caller owns the conversation. A caller hands
 * a complete [ChatSession] to [textComplete], and updated snapshots are expected to show up on
 * [response]. Implementations turn raw server payloads into sessions via [responseToChatSession].
 *
 * @param M concrete message type carried by the sessions this provider emits.
 */
abstract class LLMProvider<M : ChatMessage> {
    /** Requests a completion for [session]; results are expected on [response]. */
    abstract fun textComplete(session: ChatSession<ChatMessage>)

    /** Requests image generation for [prompt]. */
    abstract fun imageGenerate(prompt: String): Flow<ChatSession<M>>

    /** Stream of session snapshots produced by this provider. */
    abstract val response: Flow<ChatSession<M>>

    /** Folds one raw server payload, [responseBody], into [sendingSession] and returns the result. */
    protected abstract fun responseToChatSession(sendingSession: ChatSession<M>, responseBody: String): ChatSession<M>
}
/** Author of a [ChatMessage]. */
enum class Role2 {
    USER,
    BOT,
    Assistant,
}
/**
 * Streaming state of a message as reported by the backend:
 * [BEGIN] is the first chunk, [CONTENT] an intermediate chunk, and [END] means the content is complete.
 */
enum class MessageStatus {
    BEGIN,
    CONTENT,
    END,
}
/**
 * A single chat message.
 *
 * @param text message body.
 * @param role who authored the message.
 * @param status streaming state; defaults to [MessageStatus.CONTENT].
 */
open class ChatMessage(
    text: String,
    role: Role2,
    status: MessageStatus = MessageStatus.CONTENT,
) {
    val text: String = text
    val role: Role2 = role
    val status: MessageStatus = status
}
/**
 * Snapshot of a conversation, maintained by the client.
 *
 * @property conversionName name of the conversation (the identifier looks like a misspelling of
 *   "conversation" and is kept as-is for compatibility).
 * @property chatHistory messages in order, oldest first.
 * @property status current session status; the next message may be sent only when this is
 *   [MessageStatus.END].
 */
data class ChatSession<out T : ChatMessage>(
    val conversionName: String,
    val chatHistory: List<T>,
    var status: MessageStatus = MessageStatus.END,
)
/** True when [ChatSession.status] is [MessageStatus.END], i.e. the previous answer has fully arrived. */
fun <T : ChatMessage> ChatSession<T>.canSendNextMessage(): Boolean =
    status == MessageStatus.END
/** Returns a copy of this session with [newMessage] added to the end of the history; status is unchanged. */
fun <T : ChatMessage> ChatSession<T>.appendMessage(newMessage: T): ChatSession<T> {
    val extendedHistory: List<T> = chatHistory + newMessage
    return copy(chatHistory = extendedHistory)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment