Skip to content

Instantly share code, notes, and snippets.

@tianyu
Created October 6, 2023 20:16
Show Gist options
  • Save tianyu/36ecf8e5a2a936f008ea61f97e04abc0 to your computer and use it in GitHub Desktop.
Save tianyu/36ecf8e5a2a936f008ea61f97e04abc0 to your computer and use it in GitHub Desktop.
A `com.squareup.wire.GrpcClient` implementation for Kotlin/JS, built on Ktor and gRPC-Web.
import com.squareup.wire.GrpcCall
import com.squareup.wire.GrpcClient
import com.squareup.wire.GrpcMethod
import com.squareup.wire.GrpcStreamingCall
import com.squareup.wire.MessageSink
import com.squareup.wire.MessageSource
import io.ktor.client.HttpClient
import io.ktor.client.plugins.expectSuccess
import io.ktor.client.request.accept
import io.ktor.client.request.header
import io.ktor.client.request.post
import io.ktor.client.request.setBody
import io.ktor.client.request.url
import io.ktor.client.statement.bodyAsChannel
import io.ktor.http.ContentType
import io.ktor.http.content.ChannelWriterContent
import io.ktor.utils.io.ByteReadChannel
import io.ktor.utils.io.readAvailable
import io.ktor.utils.io.readFully
import io.ktor.utils.io.writeFully
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.async
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.channels.consumeEach
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
import okio.IOException
import okio.Timeout
class GrpcWebClient(private val httpClient: HttpClient, private val basePath: String = ""): GrpcClient() {
/** Creates a new, not-yet-executed unary call for [method]. */
override fun <S : Any, R : Any> newCall(method: GrpcMethod<S, R>): GrpcWebCall<S, R> {
    return GrpcWebCall(method = method)
}
/** Creates a new, not-yet-executed streaming call for [method]. */
override fun <S : Any, R : Any> newStreamingCall(method: GrpcMethod<S, R>): GrpcWebStreamingCall<S, R> {
    return GrpcWebStreamingCall(method = method)
}
/**
 * gRPC status codes, from https://grpc.github.io/grpc/core/md_doc_statuscodes.html
 *
 * Entries are declared in numeric code order, so each entry's `ordinal` equals its numeric
 * gRPC code; status values are looked up with `GrpcStatus.values()::getOrNull` on that code.
 * Do not reorder, insert, or remove entries.
 */
private enum class GrpcStatus {
OK, // 0
CANCELLED, // 1
UNKNOWN, // 2
INVALID_ARGUMENT, // 3
DEADLINE_EXCEEDED, // 4
NOT_FOUND, // 5
ALREADY_EXISTS, // 6
PERMISSION_DENIED, // 7
RESOURCE_EXHAUSTED, // 8
FAILED_PRECONDITION, // 9
ABORTED, // 10
OUT_OF_RANGE, // 11
UNIMPLEMENTED, // 12
INTERNAL, // 13
UNAVAILABLE, // 14
DATA_LOSS, // 15
UNAUTHENTICATED, // 16
}
/**
 * A unary gRPC call carried by a single HTTP POST through [httpClient].
 *
 * The request is sent as one uncompressed length-prefixed frame. The response body is read
 * frame by frame: the first data frame is decoded as the result, and any trailers frame
 * (flag bit 0x80, as used by gRPC-Web) is checked for a non-OK `grpc-status`.
 * [requestMetadata] entries are sent as request headers. [timeout] is currently not enforced.
 */
inner class GrpcWebCall<S : Any, R : Any>(
    override val method: GrpcMethod<S, R>,
    override val timeout: Timeout = Timeout.NONE,
) : GrpcCall<S, R> {
    override var requestMetadata: Map<String, String> = mapOf()
    override val responseMetadata: Map<String, String>? = null

    /** The most recent execution; null until [execute] (or [enqueue]) is first called. */
    private var response: Deferred<R>? = null

    /** Set by [cancel]; remembered so that a cancel issued before [execute] still takes effect. */
    private var canceled = false

    /** Returns a fresh, unexecuted call for the same method, timeout and request metadata. */
    override fun clone() = GrpcWebCall(method, timeout).also { it.requestMetadata = requestMetadata }

    override fun cancel() {
        canceled = true
        response?.cancel()
    }

    /** True once [cancel] has been called. */
    override fun isCanceled(): Boolean = canceled

    /** True once [execute] or [enqueue] has been called (not necessarily completed). */
    override fun isExecuted(): Boolean = response != null

    override fun executeBlocking(request: S): R {
        throw UnsupportedOperationException("Blocking execution is not supported for JS")
    }

    /**
     * Sends [request] and suspends until the response message has been decoded.
     *
     * The HTTP exchange runs as a child of the caller's coroutine, so cancelling the caller
     * (or calling [cancel]) also cancels the request.
     *
     * @throws IllegalStateException if the server reports a non-OK `grpc-status`.
     * @throws UnsupportedOperationException if a response frame is compressed or its length overflows.
     * @throws IOException if the response ends without a message.
     */
    override suspend fun execute(request: S): R = coroutineScope {
        val deferred = async { exchange(request) }
        response = deferred
        if (canceled) deferred.cancel()
        deferred.await()
    }

    /**
     * Runs [execute] in the background and reports the outcome to [callback].
     *
     * Every failure reaches [GrpcCall.Callback.onFailure]. Errors that are not [IOException]s
     * (bad status, protocol errors, cancellation) are wrapped in one, because that is the
     * callback's parameter type.
     */
    override fun enqueue(request: S, callback: GrpcCall.Callback<S, R>) {
        GlobalScope.launch {
            val result = try {
                execute(request)
            } catch (error: Throwable) {
                val ioError = error as? IOException ?: IOException(error.message, error)
                callback.onFailure(this@GrpcWebCall, ioError)
                return@launch
            }
            // Outside the try so an exception thrown by onSuccess is not reported as a call failure.
            callback.onSuccess(this@GrpcWebCall, result)
        }
    }

    /** Performs the HTTP exchange for one call and decodes its single response message. */
    private suspend fun exchange(request: S): R {
        // Computed outside the request builder: inside it, `method` would resolve to the HTTP method.
        val path = basePath + method.path
        val requestAdapter = method.requestAdapter
        val responseAdapter = method.responseAdapter
        val grpcContentType = ContentType("application", "grpc")
        val httpResponse = httpClient.post {
            url(path)
            header("TE", "trailers")
            requestMetadata.forEach { (name, value) -> header(name, value) }
            accept(grpcContentType)
            setBody(
                ChannelWriterContent(
                    body = {
                        // Frame header: flag byte 0 (uncompressed), then the 4-byte message length.
                        writeByte(0)
                        writeInt(requestAdapter.encodedSize(request))
                        writeFully(requestAdapter.encode(request))
                    },
                    contentType = grpcContentType,
                ),
            )
            expectSuccess = true
        }
        // A "trailers-only" response carries its status in the HTTP headers.
        checkGrpcStatus(httpResponse.headers["grpc-status"], httpResponse.headers["grpc-message"])

        val body = httpResponse.bodyAsChannel()
        var message: R? = null
        while (true) {
            val (flags, payload) = readFrame(body) ?: break
            when {
                (flags and 0x80) != 0 -> checkTrailers(payload)
                flags != 0 -> throw UnsupportedOperationException("Response body is compressed!")
                message == null -> message = responseAdapter.decode(payload)
            }
        }
        return message ?: throw IOException("Response ended without a message")
    }

    /**
     * Reads one gRPC length-prefixed frame: 1 flag byte, a 4-byte length, then the payload.
     * Returns `(flags, payload)`, or null if [channel] ended cleanly before a new frame began.
     */
    private suspend fun readFrame(channel: ByteReadChannel): Pair<Int, ByteArray>? {
        // Wait for data or EOF first, so a clean end of stream is not mistaken for a truncated frame.
        channel.awaitContent()
        if (channel.isClosedForRead) return null
        val flags = channel.readByte().toInt() and 0xFF
        val length = channel.readInt()
        if (length < 0) {
            throw UnsupportedOperationException("Length overflow!")
        }
        val payload = ByteArray(length)
        // readFully suspends until every byte has arrived; readAvailable may return a partial read.
        channel.readFully(payload)
        return flags to payload
    }

    /** Parses a gRPC-Web trailers frame (`name: value` lines) and checks the status it carries. */
    private fun checkTrailers(payload: ByteArray) {
        val trailers = payload.decodeToString()
            .lineSequence()
            .filter { ':' in it }
            .associate { it.substringBefore(':').trim().lowercase() to it.substringAfter(':').trim() }
        checkGrpcStatus(trailers["grpc-status"], trailers["grpc-message"])
    }

    /**
     * Throws [IllegalStateException] ("STATUS: message") unless [status] is absent or OK.
     * A status that is present but unparsable or out of range is reported as UNKNOWN.
     */
    private fun checkGrpcStatus(status: String?, message: String?) {
        if (status == null) return
        val grpcStatus = status.trim().toIntOrNull()?.let(GrpcStatus.values()::getOrNull) ?: GrpcStatus.UNKNOWN
        check(grpcStatus == GrpcStatus.OK) { "$grpcStatus: ${message ?: "<no-message>"}" }
    }
}
/**
 * A streaming gRPC call carried by a single HTTP POST through [httpClient].
 *
 * Messages sent on the request channel are written as uncompressed length-prefixed frames in
 * the request body. Response data frames are decoded and delivered on the response channel;
 * a trailers frame (flag bit 0x80, as used by gRPC-Web) is checked for a non-OK `grpc-status`.
 * [requestMetadata] entries are sent as request headers. [timeout] is currently not enforced.
 *
 * NOTE(review): whether requests are really streamed, or buffered until the request channel
 * closes, depends on the Ktor engine (browser fetch typically buffers request bodies) — confirm.
 */
inner class GrpcWebStreamingCall<S : Any, R : Any>(
    override val method: GrpcMethod<S, R>,
    override val timeout: Timeout = Timeout.NONE,
) : GrpcStreamingCall<S, R> {
    override var requestMetadata: Map<String, String> = mapOf()
    override val responseMetadata: Map<String, String>? = null

    /** The most recent execution; null until [executeIn] (or [execute]) is first called. */
    private var job: Job? = null

    /** Set by [cancel]; remembered so that a cancel issued before [executeIn] still takes effect. */
    private var canceled = false

    /** Returns a fresh, unexecuted call for the same method, timeout and request metadata. */
    override fun clone(): GrpcStreamingCall<S, R> =
        GrpcWebStreamingCall(method, timeout).also { it.requestMetadata = requestMetadata }

    override fun cancel() {
        canceled = true
        job?.cancel()
    }

    /** True once [cancel] has been called. */
    override fun isCanceled(): Boolean = canceled

    /** True once [executeIn] (or [execute]) has been called (not necessarily completed). */
    override fun isExecuted(): Boolean = job != null

    override fun executeBlocking(): Pair<MessageSink<S>, MessageSource<R>> {
        throw UnsupportedOperationException("Blocking execution is not supported for JS")
    }

    @Deprecated(
        "Provide a scope, preferably not GlobalScope",
        replaceWith = ReplaceWith("executeIn(GlobalScope)", "kotlinx.coroutines.GlobalScope"),
        level = DeprecationLevel.WARNING
    )
    override fun execute(): Pair<SendChannel<S>, ReceiveChannel<R>> {
        return executeIn(GlobalScope)
    }

    /**
     * Starts the call in [scope] and returns the (requests, responses) channel pair.
     *
     * The response channel is always terminated when the call ends:
     * - it is closed normally on success or cancellation;
     * - it is closed with the error on failure, so consumers see that error.
     *
     * The request channel is cancelled at that point, so senders do not suspend forever.
     * A failure is also rethrown inside [scope], as before.
     */
    override fun executeIn(scope: CoroutineScope): Pair<SendChannel<S>, ReceiveChannel<R>> {
        // Computed outside the request builder: inside it, `method` would resolve to the HTTP method.
        val path = basePath + method.path
        val requestAdapter = method.requestAdapter
        val responseAdapter = method.responseAdapter
        val grpcContentType = ContentType("application", "grpc")
        val requests = Channel<S>()
        val responses = Channel<R>()
        val launched = scope.launch {
            val httpResponse = httpClient.post {
                url(path)
                header("TE", "trailers")
                requestMetadata.forEach { (name, value) -> header(name, value) }
                accept(grpcContentType)
                setBody(
                    ChannelWriterContent(
                        body = {
                            requests.consumeEach { request ->
                                // Frame header: flag byte 0 (uncompressed), then the 4-byte message length.
                                writeByte(0)
                                writeInt(requestAdapter.encodedSize(request))
                                writeFully(requestAdapter.encode(request))
                            }
                        },
                        contentType = grpcContentType,
                    ),
                )
                expectSuccess = true
            }
            // A "trailers-only" response carries its status in the HTTP headers.
            checkGrpcStatus(httpResponse.headers["grpc-status"], httpResponse.headers["grpc-message"])

            val body = httpResponse.bodyAsChannel()
            while (true) {
                val (flags, payload) = readFrame(body) ?: break
                when {
                    (flags and 0x80) != 0 -> checkTrailers(payload)
                    flags != 0 -> throw UnsupportedOperationException("Response body is compressed!")
                    else -> responses.send(responseAdapter.decode(payload))
                }
            }
        }
        // Runs on success, failure and cancellation (even if the coroutine never started), so the
        // consumer is never left waiting on an unterminated response channel.
        launched.invokeOnCompletion { cause ->
            if (cause == null || cause is CancellationException) responses.close() else responses.close(cause)
            requests.cancel()
        }
        job = launched
        if (canceled) launched.cancel()
        return requests to responses
    }

    /**
     * Reads one gRPC length-prefixed frame: 1 flag byte, a 4-byte length, then the payload.
     * Returns `(flags, payload)`, or null if [channel] ended cleanly before a new frame began.
     */
    private suspend fun readFrame(channel: ByteReadChannel): Pair<Int, ByteArray>? {
        // Wait for data or EOF first, so a clean end of stream is not mistaken for a truncated frame.
        channel.awaitContent()
        if (channel.isClosedForRead) return null
        val flags = channel.readByte().toInt() and 0xFF
        val length = channel.readInt()
        if (length < 0) {
            throw UnsupportedOperationException("Length overflow!")
        }
        val payload = ByteArray(length)
        // readFully suspends until every byte has arrived; readAvailable may return a partial read.
        channel.readFully(payload)
        return flags to payload
    }

    /** Parses a gRPC-Web trailers frame (`name: value` lines) and checks the status it carries. */
    private fun checkTrailers(payload: ByteArray) {
        val trailers = payload.decodeToString()
            .lineSequence()
            .filter { ':' in it }
            .associate { it.substringBefore(':').trim().lowercase() to it.substringAfter(':').trim() }
        checkGrpcStatus(trailers["grpc-status"], trailers["grpc-message"])
    }

    /**
     * Throws [IllegalStateException] ("STATUS: message") unless [status] is absent or OK.
     * A status that is present but unparsable or out of range is reported as UNKNOWN.
     */
    private fun checkGrpcStatus(status: String?, message: String?) {
        if (status == null) return
        val grpcStatus = status.trim().toIntOrNull()?.let(GrpcStatus.values()::getOrNull) ?: GrpcStatus.UNKNOWN
        check(grpcStatus == GrpcStatus.OK) { "$grpcStatus: ${message ?: "<no-message>"}" }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment