Created October 6, 2023 20:16
Save tianyu/36ecf8e5a2a936f008ea61f97e04abc0 to your computer and use it in GitHub Desktop.
A Square Wire `GrpcClient` for Kotlin/JS, built on Ktor and gRPC-Web.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters.
import com.squareup.wire.GrpcCall
import com.squareup.wire.GrpcClient
import com.squareup.wire.GrpcMethod
import com.squareup.wire.GrpcStreamingCall
import com.squareup.wire.MessageSink
import com.squareup.wire.MessageSource
import io.ktor.client.HttpClient
import io.ktor.client.plugins.expectSuccess
import io.ktor.client.request.accept
import io.ktor.client.request.header
import io.ktor.client.request.post
import io.ktor.client.request.setBody
import io.ktor.client.request.url
import io.ktor.client.statement.HttpResponse
import io.ktor.client.statement.bodyAsChannel
import io.ktor.http.ContentType
import io.ktor.http.content.ChannelWriterContent
import io.ktor.utils.io.ByteReadChannel
import io.ktor.utils.io.ByteWriteChannel
import io.ktor.utils.io.readAvailable
import io.ktor.utils.io.writeFully
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.async
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.channels.consumeEach
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
import okio.IOException
import okio.Timeout
/**
 * A Wire [GrpcClient] for Kotlin/JS that performs gRPC calls through a ktor [HttpClient].
 *
 * Each call is an HTTP POST to `basePath + method.path` with content type `application/grpc`.
 * Request and response messages use gRPC length-prefixed framing:
 * - a 1-byte compression flag (always 0 here; compressed responses are rejected);
 * - a 4-byte message length;
 * - the encoded message bytes.
 *
 * Limitations of this implementation:
 * - only the `grpc-status` / `grpc-message` response *headers* are checked; HTTP trailers are not read;
 * - a call's `timeout` is accepted but not applied;
 * - blocking execution is unsupported.
 *
 * @param httpClient ktor client used to send every request.
 * @param basePath prefix prepended to each [GrpcMethod.path] (e.g. the server's origin).
 */
class GrpcWebClient(
    private val httpClient: HttpClient,
    private val basePath: String = "",
) : GrpcClient() {

    /** Content type used for the request body and the `Accept` header. */
    private val grpcContentType = ContentType("application", "grpc")

    override fun <S : Any, R : Any> newCall(method: GrpcMethod<S, R>): GrpcWebCall<S, R> = GrpcWebCall(method)

    override fun <S : Any, R : Any> newStreamingCall(method: GrpcMethod<S, R>): GrpcWebStreamingCall<S, R> =
        GrpcWebStreamingCall(method)

    /**
     * gRPC status codes, declared in numeric order so that each constant's ordinal equals its code.
     *
     * From https://grpc.github.io/grpc/core/md_doc_statuscodes.html
     */
    private enum class GrpcStatus {
        OK,
        CANCELLED,
        UNKNOWN,
        INVALID_ARGUMENT,
        DEADLINE_EXCEEDED,
        NOT_FOUND,
        ALREADY_EXISTS,
        PERMISSION_DENIED,
        RESOURCE_EXHAUSTED,
        FAILED_PRECONDITION,
        ABORTED,
        OUT_OF_RANGE,
        UNIMPLEMENTED,
        INTERNAL,
        UNAVAILABLE,
        DATA_LOSS,
        UNAUTHENTICATED,
    }

    /**
     * Sends a gRPC POST for [method] and returns the response once its headers have been validated.
     *
     * @param metadata extra request headers (the call's `requestMetadata`).
     * @param writeFrames writes the request body, normally via [writeGrpcFrame].
     * @throws IllegalStateException if a non-OK `grpc-status` header is present (see [checkGrpcStatus]).
     * Because `expectSuccess` is set, a non-2xx HTTP status also makes ktor throw.
     */
    private suspend fun postGrpc(
        method: GrpcMethod<*, *>,
        metadata: Map<String, String>,
        writeFrames: suspend ByteWriteChannel.() -> Unit,
    ): HttpResponse {
        val response = httpClient.post {
            url(basePath + method.path)
            header("TE", "trailers")
            // Previously the call's requestMetadata was accepted but never transmitted.
            metadata.forEach { (name, value) -> header(name, value) }
            accept(grpcContentType)
            setBody(ChannelWriterContent(body = writeFrames, contentType = grpcContentType))
            expectSuccess = true
        }
        checkGrpcStatus(response)
        return response
    }

    /**
     * Throws [IllegalStateException] ("STATUS: message") if [response] carries a non-OK `grpc-status` header.
     *
     * A missing header is treated as OK. A header that is present but unparseable or out of range is
     * reported as [GrpcStatus.UNKNOWN] instead of being silently accepted as OK.
     */
    private fun checkGrpcStatus(response: HttpResponse) {
        val rawStatus = response.headers["grpc-status"] ?: return
        val status = rawStatus.trim().toIntOrNull()?.let { code -> GrpcStatus.values().getOrNull(code) }
            ?: GrpcStatus.UNKNOWN
        check(status == GrpcStatus.OK) {
            val grpcMessage = response.headers["grpc-message"] ?: "<no-message>"
            "$status: $grpcMessage"
        }
    }

    /** Flattens the response headers into Wire's metadata map, joining repeated values with commas. */
    private fun HttpResponse.grpcMetadata(): Map<String, String> =
        headers.entries().associate { (name, values) -> name to values.joinToString(",") }

    /**
     * Writes [message] as one uncompressed gRPC frame.
     *
     * The frame is a 0 flag byte, the 4-byte length, then the bytes. The length is taken from the
     * already-encoded bytes, so it always matches what is written.
     */
    private suspend fun ByteWriteChannel.writeGrpcFrame(message: ByteArray) {
        writeByte(0)
        writeInt(message.size)
        writeFully(message)
    }

    /**
     * Reads the next gRPC frame's message bytes, or returns `null` if the stream ended cleanly between frames.
     *
     * @throws UnsupportedOperationException for a non-zero (compressed) flag, a negative length,
     * or a stream that ends in the middle of a message.
     */
    private suspend fun ByteReadChannel.readGrpcMessage(): ByteArray? {
        // Wait for data or end-of-stream first, so a clean EOF after the last frame is not a read error.
        awaitContent()
        if (isClosedForRead) return null

        val compressed = readByte()
        if (compressed != 0.toByte()) {
            throw UnsupportedOperationException("Response body is compressed!")
        }
        val length = readInt()
        if (length < 0) {
            throw UnsupportedOperationException("Length overflow!")
        }
        val buffer = ByteArray(length)
        // A single readAvailable() may return fewer bytes than requested, so keep reading until the frame is full.
        var filled = 0
        while (filled < length) {
            val read = readAvailable(buffer, filled, length - filled)
            if (read < 0) {
                throw UnsupportedOperationException("Oh no! Couldn't read all the bytes")
            }
            filled += read
        }
        return buffer
    }

    /** A unary (single request, single response) call. */
    inner class GrpcWebCall<S : Any, R : Any>(
        override val method: GrpcMethod<S, R>,
        override val timeout: Timeout = Timeout.NONE,
    ) : GrpcCall<S, R> {
        override var requestMetadata: Map<String, String> = mapOf()

        /** Headers of the HTTP response; `null` until a response has been received. */
        override var responseMetadata: Map<String, String>? = null
            private set

        /** The in-flight request; `null` until [execute] starts one. */
        private var response: Deferred<R>? = null

        /** Set by [cancel], so that cancelling before [execute] still takes effect. */
        private var canceled = false

        /** Set once [execute] or [enqueue] has been called. */
        private var executed = false

        /** Returns a fresh, unexecuted call for the same method, carrying a copy of [requestMetadata]. */
        override fun clone(): GrpcWebCall<S, R> =
            GrpcWebCall(method, timeout).also { it.requestMetadata = requestMetadata }

        override fun cancel() {
            canceled = true
            response?.cancel()
        }

        override fun isCanceled(): Boolean = canceled || response?.isCancelled == true

        override fun isExecuted(): Boolean = executed

        override fun executeBlocking(request: S): R {
            throw UnsupportedOperationException("Blocking execution is not supported for JS")
        }

        /**
         * Sends [request] and suspends until the single response message is decoded.
         *
         * The request runs as a child of the caller's coroutine, so cancelling the caller cancels the
         * HTTP call. [cancel] cancels it too; [execute] then throws a `CancellationException`.
         *
         * @throws IllegalStateException on a non-OK `grpc-status` header or an empty response body.
         * @throws UnsupportedOperationException on malformed or compressed response framing.
         */
        override suspend fun execute(request: S): R = coroutineScope {
            executed = true
            val call = async {
                val http = postGrpc(method, requestMetadata) {
                    writeGrpcFrame(method.requestAdapter.encode(request))
                }
                responseMetadata = http.grpcMetadata()
                val message = http.bodyAsChannel().readGrpcMessage()
                    ?: throw IllegalStateException("Response body contained no gRPC message")
                method.responseAdapter.decode(message)
            }
            response = call
            if (canceled) call.cancel()
            call.await()
        }

        /**
         * Executes asynchronously and reports the outcome to [callback] exactly once.
         *
         * [GrpcCall.Callback.onFailure] only accepts an [IOException]. Any other failure, including
         * cancellation via [cancel], is therefore wrapped in one, so the callback is always notified.
         */
        override fun enqueue(request: S, callback: GrpcCall.Callback<S, R>) {
            executed = true
            GlobalScope.launch {
                val result = try {
                    execute(request)
                } catch (error: Exception) {
                    // This launched coroutine itself is never cancelled from outside. A CancellationException
                    // here comes from the call's own cancel(), so it is reported rather than rethrown.
                    val ioError = error as? IOException ?: IOException(error.message, error)
                    callback.onFailure(this@GrpcWebCall, ioError)
                    return@launch
                }
                // Outside the try, so a throwing onSuccess is not also reported as onFailure.
                callback.onSuccess(this@GrpcWebCall, result)
            }
        }
    }

    /**
     * A streaming call: requests sent to the returned [SendChannel] are framed into one HTTP request body,
     * and response frames are decoded into the returned [ReceiveChannel].
     *
     * NOTE(review): the whole request body is produced inside `ChannelWriterContent`. Whether the engine
     * streams it or buffers it until the request channel is closed depends on the ktor engine/browser,
     * so confirm before relying on interleaved send/receive.
     */
    inner class GrpcWebStreamingCall<S : Any, R : Any>(
        override val method: GrpcMethod<S, R>,
        override val timeout: Timeout = Timeout.NONE,
    ) : GrpcStreamingCall<S, R> {
        override var requestMetadata: Map<String, String> = mapOf()

        /** Headers of the HTTP response; `null` until a response has been received. */
        override var responseMetadata: Map<String, String>? = null
            private set

        /** The coroutine running the call; `null` until [executeIn] is invoked. */
        private var job: Job? = null

        /** Set by [cancel], so that cancelling before [executeIn] still takes effect. */
        private var canceled = false

        /** Returns a fresh, unexecuted call for the same method, carrying a copy of [requestMetadata]. */
        override fun clone(): GrpcStreamingCall<S, R> =
            GrpcWebStreamingCall(method, timeout).also { it.requestMetadata = requestMetadata }

        override fun cancel() {
            canceled = true
            job?.cancel()
        }

        override fun isCanceled(): Boolean = canceled || job?.isCancelled == true

        override fun isExecuted(): Boolean = job != null

        override fun executeBlocking(): Pair<MessageSink<S>, MessageSource<R>> {
            throw UnsupportedOperationException("Blocking execution is not supported for JS")
        }

        @Deprecated(
            "Provide a scope, preferably not GlobalScope",
            replaceWith = ReplaceWith("executeIn(GlobalScope)", "kotlinx.coroutines.GlobalScope"),
            level = DeprecationLevel.WARNING
        )
        override fun execute(): Pair<SendChannel<S>, ReceiveChannel<R>> {
            return executeIn(GlobalScope)
        }

        /**
         * Starts the call in [scope] and returns the (requests, responses) channel pair.
         *
         * The responses channel is closed normally at end of stream. On any failure it is closed with that
         * failure, so consumers see the error instead of hanging or seeing a normal end. The requests channel
         * is cancelled so blocked senders are released. The failure is then rethrown into [scope], as before.
         */
        override fun executeIn(scope: CoroutineScope): Pair<SendChannel<S>, ReceiveChannel<R>> {
            val requests = Channel<S>()
            val responses = Channel<R>()
            val call = scope.launch {
                try {
                    val http = postGrpc(method, requestMetadata) {
                        requests.consumeEach { request ->
                            writeGrpcFrame(method.requestAdapter.encode(request))
                        }
                    }
                    responseMetadata = http.grpcMetadata()
                    val body = http.bodyAsChannel()
                    while (true) {
                        val message = body.readGrpcMessage() ?: break
                        responses.send(method.responseAdapter.decode(message))
                    }
                    responses.close()
                } catch (error: Throwable) {
                    requests.cancel()
                    responses.close(error)
                    throw error
                }
            }
            job = call
            if (canceled) call.cancel()
            return requests to responses
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.