Skip to content

Instantly share code, notes, and snippets.

@watabee
Created February 23, 2019 05:18
Show Gist options
  • Save watabee/7dd034379d8fe30409ed88de2296bbea to your computer and use it in GitHub Desktop.
Save watabee/7dd034379d8fe30409ed88de2296bbea to your computer and use it in GitHub Desktop.
A ktor client engine for sharing OkHttpClient
import io.ktor.client.call.HttpClientCall
import io.ktor.client.call.HttpEngineCall
import io.ktor.client.call.UnsupportedContentTypeException
import io.ktor.client.engine.HttpClientEngine
import io.ktor.client.engine.HttpClientEngineConfig
import io.ktor.client.engine.HttpClientEngineFactory
import io.ktor.client.engine.HttpClientJvmEngine
import io.ktor.client.engine.mergeHeaders
import io.ktor.client.request.DefaultHttpRequest
import io.ktor.client.request.HttpRequestData
import io.ktor.client.response.HttpResponse
import io.ktor.http.Headers
import io.ktor.http.HttpProtocolVersion
import io.ktor.http.HttpStatusCode
import io.ktor.http.content.OutgoingContent
import io.ktor.util.InternalAPI
import io.ktor.util.KtorExperimentalAPI
import io.ktor.util.cio.KtorDefaultPool
import io.ktor.util.cio.toByteReadChannel
import io.ktor.util.date.GMTDate
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.io.ByteReadChannel
import kotlinx.coroutines.io.jvm.javaio.toInputStream
import kotlinx.coroutines.io.writer
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.coroutines.withContext
import okhttp3.Call
import okhttp3.Callback
import okhttp3.MediaType
import okhttp3.OkHttpClient
import okhttp3.Protocol
import okhttp3.Request
import okhttp3.RequestBody
import okhttp3.Response
import okio.BufferedSink
import okio.Okio
import java.io.IOException
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException
/**
 * [HttpClientEngineFactory] that builds [SharedOkHttpEngine]s around a single [OkHttpClient].
 *
 * Every engine produced by [create] executes its requests on the same [client] instance.
 *
 * @param client the OkHttp client handed to each created engine.
 */
class SharedOkHttp(
    private val client: OkHttpClient
) : HttpClientEngineFactory<HttpClientEngineConfig> {

    /**
     * Creates an engine whose configuration is a fresh [HttpClientEngineConfig] customised by [block].
     */
    override fun create(block: HttpClientEngineConfig.() -> Unit): HttpClientEngine {
        val engineConfig = HttpClientEngineConfig()
        engineConfig.block()
        return SharedOkHttpEngine(client, engineConfig)
    }
}
/**
 * A ktor client engine that executes requests on a caller-supplied [OkHttpClient].
 *
 * The engine does not build a client of its own; every request goes through [client], so the
 * same OkHttp instance can be shared between ktor and other code.
 *
 * @param client the OkHttp client used to execute every request.
 * @param config the engine configuration.
 */
@UseExperimental(InternalAPI::class, KtorExperimentalAPI::class)
class SharedOkHttpEngine(
    private val client: OkHttpClient,
    override val config: HttpClientEngineConfig
) : HttpClientJvmEngine("ktor-shared-okhttp") {

    override suspend fun execute(call: HttpClientCall, data: HttpRequestData): HttpEngineCall {
        val request = DefaultHttpRequest(call, data)
        val requestTime = GMTDate()
        val callContext = createCallContext()

        val methodName = request.method.value
        // OkHttp rejects a null body for POST/PUT/PATCH/PROPPATCH/REPORT, so a request with
        // NoContent on one of those methods gets an empty body instead of failing.
        val requestBody = request.content.convertToOkHttpBody(callContext) ?: emptyBodyIfRequired(methodName)

        val okHttpRequest = Request.Builder().apply {
            url(request.url.toString())
            mergeHeaders(request.headers, request.content) { key, value ->
                addHeader(key, value)
            }
            method(methodName, requestBody)
        }.build()

        val response = client.execute(okHttpRequest)
        val body = response.body()
        // Close the response body once the call's job completes.
        callContext[Job]?.invokeOnCompletion { body?.close() }

        val responseContent = withContext(callContext) {
            body?.byteStream()?.toByteReadChannel(
                context = callContext,
                pool = KtorDefaultPool
            ) ?: ByteReadChannel.Empty
        }

        return HttpEngineCall(
            request,
            OkHttpResponse(response, call, requestTime, responseContent, callContext)
        )
    }

    /**
     * Returns an empty [RequestBody] for methods OkHttp requires a body on, `null` otherwise.
     *
     * The method list mirrors OkHttp's own `HttpMethod.requiresRequestBody`.
     */
    private fun emptyBodyIfRequired(methodName: String): RequestBody? = when (methodName) {
        "POST", "PUT", "PATCH", "PROPPATCH", "REPORT" -> RequestBody.create(null, ByteArray(0))
        else -> null
    }
}
/**
 * Converts ktor [OutgoingContent] into an OkHttp [RequestBody].
 *
 * Returns `null` for [OutgoingContent.NoContent]. Streaming variants are wrapped in
 * [StreamRequestBody], whose supplier lambda is invoked to obtain the channel to read from.
 *
 * @param callContext coroutine context for the writer coroutine that produces the bytes of an
 *   [OutgoingContent.WriteChannelContent].
 * @throws UnsupportedContentTypeException for any other [OutgoingContent] subtype.
 */
private fun OutgoingContent.convertToOkHttpBody(callContext: CoroutineContext): RequestBody? {
    val content = this
    return when (content) {
        is OutgoingContent.NoContent -> null
        is OutgoingContent.ByteArrayContent -> RequestBody.create(null, content.bytes())
        is OutgoingContent.ReadChannelContent -> StreamRequestBody(content.contentLength) { content.readFrom() }
        is OutgoingContent.WriteChannelContent -> StreamRequestBody(content.contentLength) {
            // Run writeTo() in a writer coroutine and hand back the channel it fills.
            GlobalScope.writer(callContext) { content.writeTo(channel) }.channel
        }
        else -> throw UnsupportedContentTypeException(this)
    }
}
/**
 * Adapts an OkHttp [Response] to ktor's [HttpResponse].
 *
 * @param response the OkHttp response whose headers, status and protocol are exposed.
 * @param call the ktor call this response belongs to.
 * @param requestTime timestamp taken when the request was started.
 * @param content the response body as a [ByteReadChannel].
 * @param coroutineContext the call context of this response.
 */
private class OkHttpResponse(
    private val response: Response,
    override val call: HttpClientCall,
    override val requestTime: GMTDate,
    override val content: ByteReadChannel,
    override val coroutineContext: CoroutineContext
) : HttpResponse {

    override val headers: Headers = object : Headers {
        private val instance = response.headers()

        // OkHttp matches header names case-insensitively, so report that to ktor.
        override val caseInsensitiveName: Boolean = true

        // OkHttp returns an empty list for an absent header; ktor's contract (and its default
        // `contains`/`get`) expects null in that case.
        override fun getAll(name: String): List<String>? = instance.values(name).takeIf { it.isNotEmpty() }

        override fun names(): Set<String> = instance.names()

        override fun entries(): Set<Map.Entry<String, List<String>>> = instance.toMultimap().entries

        override fun isEmpty(): Boolean = instance.size() == 0
    }

    override val status: HttpStatusCode = HttpStatusCode(response.code(), response.message())

    override val version: HttpProtocolVersion = response.protocol().fromOkHttp()

    // Captured when this adapter is constructed.
    override val responseTime: GMTDate = GMTDate()
}
/**
 * Maps an OkHttp [Protocol] to the matching ktor [HttpProtocolVersion].
 *
 * Both HTTP/2 variants (negotiated and prior-knowledge) are reported as HTTP/2.0. The `when`
 * is exhaustive with no `else`, so a new OkHttp protocol constant surfaces as a compile error.
 */
// Deprecation suppressed, presumably for the legacy SPDY_3 constants.
@Suppress("DEPRECATION")
private fun Protocol.fromOkHttp(): HttpProtocolVersion = when (this) {
    Protocol.HTTP_1_0 -> HttpProtocolVersion.HTTP_1_0
    Protocol.HTTP_1_1 -> HttpProtocolVersion.HTTP_1_1
    Protocol.HTTP_2, Protocol.H2_PRIOR_KNOWLEDGE -> HttpProtocolVersion.HTTP_2_0
    Protocol.SPDY_3 -> HttpProtocolVersion.SPDY_3
    Protocol.QUIC -> HttpProtocolVersion.QUIC
}
/**
 * OkHttp [RequestBody] that streams its bytes from a ktor [ByteReadChannel].
 *
 * [block] is invoked on every [writeTo] call, and the channel it returns is copied into the sink.
 * No content type is reported by this body ([contentType] returns `null`).
 *
 * @param contentLength body length in bytes, or `null` when unknown (reported to OkHttp as -1).
 * @param block supplies the channel to read the body from.
 */
private class StreamRequestBody(
    private val contentLength: Long?,
    private val block: () -> ByteReadChannel
) : RequestBody() {

    override fun contentType(): MediaType? = null

    override fun contentLength(): Long = contentLength ?: -1

    override fun writeTo(sink: BufferedSink) {
        val input = block().toInputStream()
        Okio.source(input).use { source -> sink.writeAll(source) }
    }
}
/**
 * Executes [request] with [Call.enqueue] and suspends until OkHttp reports a response or a failure.
 *
 * Cancelling the calling coroutine cancels the OkHttp call. A response that still arrives after
 * the call or the coroutine was cancelled is not delivered to anyone, so its body is closed here
 * instead of being leaked.
 *
 * @return the OkHttp [Response]; the caller is responsible for closing its body.
 * @throws IOException when OkHttp reports a failure for a call that was not cancelled.
 */
private suspend fun OkHttpClient.execute(request: Request): Response =
    suspendCancellableCoroutine { continuation ->
        val call = newCall(request)
        val callback = object : Callback {
            override fun onFailure(call: Call, cause: IOException) {
                // A cancelled call fails with "Canceled"; the coroutine is already cancelled then.
                if (!call.isCanceled) continuation.resumeWithException(cause)
            }

            override fun onResponse(call: Call, response: Response) {
                if (call.isCanceled || !continuation.isActive) {
                    // Nobody will receive this response; release its body now.
                    response.body()?.close()
                } else {
                    continuation.resume(response)
                }
            }
        }
        call.enqueue(callback)
        continuation.invokeOnCancellation {
            call.cancel()
        }
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment