Skip to content

Instantly share code, notes, and snippets.

@Exerosis
Created June 1, 2018 08:28
Show Gist options
  • Save Exerosis/37f8ab85866b4911c7d23930987dcd46 to your computer and use it in GitHub Desktop.
Save Exerosis/37f8ab85866b4911c7d23930987dcd46 to your computer and use it in GitHub Desktop.
package com.mynt.network.implementation
import com.mynt.network.Connection
import com.mynt.network.Read
import com.mynt.network.Write
import java.nio.ByteBuffer
import kotlin.coroutines.experimental.Continuation
import kotlin.coroutines.experimental.EmptyCoroutineContext
import kotlin.coroutines.experimental.intrinsics.suspendCoroutineUninterceptedOrReturn
// Primitive decoders: each performs one relative get at the buffer's current position (advancing it)
// using the buffer's configured byte order.
val BYTE: (ByteBuffer) -> Byte = { source -> source.get() }
val SHORT: (ByteBuffer) -> Short = { source -> source.getShort() }
val INT: (ByteBuffer) -> Int = { source -> source.getInt() }
val FLOAT: (ByteBuffer) -> Float = { source -> source.getFloat() }
val LONG: (ByteBuffer) -> Long = { source -> source.getLong() }
val DOUBLE: (ByteBuffer) -> Double = { source -> source.getDouble() }
/**
 * Suspends the caller and hands [block] the raw (unintercepted) continuation.
 *
 * [block] either returns the result directly, or returns `COROUTINE_SUSPENDED` and arranges for the
 * continuation to be resumed later. Any value other than `COROUTINE_SUSPENDED` is taken as the result.
 */
suspend inline fun <Type> continued(
    crossinline block: (Continuation<Type>) -> Any?
): Type = suspendCoroutineUninterceptedOrReturn { continuation -> block(continuation) }
/**
 * Adapts a plain callback into a [Continuation] whose context is [EmptyCoroutineContext].
 *
 * A successful resume forwards the value to [callback]. Resuming with an exception rethrows it
 * to whoever called `resumeWithException`, rather than delivering it anywhere.
 */
fun <Type> continuation(callback: (Type) -> Unit): Continuation<Type> =
    object : Continuation<Type> {
        override val context get() = EmptyCoroutineContext

        override fun resume(value: Type) {
            callback(value)
        }

        override fun resumeWithException(exception: Throwable) {
            throw exception
        }
    }
//TODO little bit of a weird way of doing this, but I think it makes sense.
/**
 * Runs [block] with this connection's [Read] side as the receiver.
 *
 * Because this is inline, a lambda written inside a suspend function may itself call suspend functions.
 */
inline fun Connection.read(
    block: Read.() -> Unit
) {
    read.block()
}
/**
 * Runs [block] with this connection's [Write] side as the receiver.
 *
 * Because this is inline, a lambda written inside a suspend function may itself call suspend functions.
 */
inline fun Connection.write(
    block: Write.() -> Unit
) {
    write.block()
}
/**
 * Single-slot container: [hold] stores a value only while the slot is empty, and [release] takes it out.
 *
 * `null` marks the slot as empty, so holding `null` (possible when [Type] is nullable) leaves it empty.
 */
open class Holder<Type> {
    /** The held value, or `null` when nothing is held. */
    protected var value: Type? = null

    /**
     * Stores [value] if the slot is currently empty.
     *
     * @return `true` if [value] was stored; `false` if something was already held (the slot is unchanged).
     */
    fun hold(value: Type): Boolean {
        if (this.value != null) return false
        this.value = value
        return true
    }

    /**
     * Removes and returns the held value, leaving the slot empty.
     *
     * @throws IllegalStateException if nothing is held.
     */
    fun release(): Type {
        val held = value ?: throw IllegalStateException("Holder is empty: nothing to release")
        value = null
        // Return the captured value: the field was just cleared, so reading it again would always fail.
        return held
    }
}
/**
 * Base exception for "operation already in progress" errors (going by its name; nothing in this file
 * throws it directly).
 *
 * @param message optional detail message. The default `null` matches the previous no-argument form.
 */
open class InProgressException(message: String? = null) : IllegalStateException(message)
package com.mynt.network.implementation.sequential
import com.mynt.network.Read
import com.mynt.network.ReadCoordinator
import com.mynt.network.implementation.*
import java.nio.ByteBuffer
//TODO this probably can't stay static forever.
/**
 * [Read] that serves every request from one staging [buffer] kept for the life of this object,
 * delegating the buffering and suspension work to [read].
 *
 * Each method passes the coordinator's return value straight to [continued]. The coordinator must
 * therefore return either the finished value, or `COROUTINE_SUSPENDED` and resume the continuation later.
 *
 * @param read coordinator that fills [buffer] and decodes values from it.
 * @param buffer staging buffer handed to the coordinator on every call.
 */
open class SequentialRead(
    private val read: ReadCoordinator,
    private val buffer: ByteBuffer
) : Read {
    //--Complex--
    /** Reads [amount] bytes into [array], starting at index [offset], and returns [array]. */
    override suspend fun array(
        array: ByteArray,
        amount: Int,
        offset: Int
    ) = continued<ByteArray> {
        // The coordinator takes offset BEFORE amount: SequentialReadCoordinator.array is declared as
        // (using, array, offset, amount, continuation). Forwarding in this method's (amount, offset)
        // order would swap the two.
        read.array(buffer, array, offset, amount, it)
    }

    /** Reads into [buffer] (the caller's destination; the staging buffer is `this.buffer`) and returns it. */
    override suspend fun buffer(
        buffer: ByteBuffer
    ) = continued<ByteBuffer> {
        read.buffer(this.buffer, buffer, it)
    }

    //--Primitive--
    // Each primitive passes its encoded width in bytes together with the matching decoder.
    override suspend fun byte() = continued<Byte> {
        read.number(buffer, 1, BYTE, it)
    }

    override suspend fun short() = continued<Short> {
        read.number(buffer, 2, SHORT, it)
    }

    override suspend fun int() = continued<Int> {
        read.number(buffer, 4, INT, it)
    }

    override suspend fun float() = continued<Float> {
        read.number(buffer, 4, FLOAT, it)
    }

    override suspend fun long() = continued<Long> {
        read.number(buffer, 8, LONG, it)
    }

    override suspend fun double() = continued<Double> {
        read.number(buffer, 8, DOUBLE, it)
    }
}
package com.mynt.network.implementation.sequential
import com.mynt.network.ReadCoordinator
import java.nio.ByteBuffer
import java.nio.channels.CompletionHandler
import kotlin.coroutines.experimental.Continuation
import kotlin.coroutines.experimental.intrinsics.COROUTINE_SUSPENDED
/**
 * [ReadCoordinator] that serves reads from a caller-supplied staging buffer (`using`). When that buffer
 * does not hold enough bytes, it requests more through the callback-style [read] function and finishes
 * the request from a [CompletionHandler] callback.
 *
 * Each public method returns either the finished value or `COROUTINE_SUSPENDED`. In the latter case the
 * continuation is later resumed, or resumed with an exception, from `completed`/`failed`.
 * Each handler (array, buffer, primitive) allows at most one suspended request. Starting another
 * request on the same handler while one is suspended throws [ReadInProgressException].
 *
 * NOTE(review): the handlers disagree about the state `using` should be in on entry. The array fast
 * path and the buffer handler use `using` as-is, while the primitive fast path calls flip() first.
 * Confirm the intended convention before mixing array and primitive reads.
 * NOTE(review): no `completed` implementation checks for `count == -1` (end of stream).
 *
 * @param read starts an asynchronous read into the given buffer and reports to the given handler.
 *   The handlers here treat the handler's attachment as the buffer that was read into.
 */
class SequentialReadCoordinator(
private val read: (ByteBuffer, CompletionHandler<Int, ByteBuffer>) -> Unit
) : ReadCoordinator {
/** Thrown when a handler is asked to start a request while its previous request is still suspended. */
class ReadInProgressException : IllegalStateException()
// Shared bookkeeping for one kind of pending request: the bytes still missing (`required`) and the
// continuation to resume. A non-null `continuation` means a request is currently suspended.
private abstract class Handler<Type> : CompletionHandler<Int, ByteBuffer> {
var required = 0
var continuation: Continuation<Type>? = null
// The async read failed. Clear the slot first (so the resumed code may start a new request),
// then deliver the error to the waiting coroutine.
override fun failed(reason: Throwable, buffer: ByteBuffer) {
val temp = continuation!!
continuation = null
temp.resumeWithException(reason)
}
// Finishes the pending request with `value`, clearing the slot before resuming for the same reason
// as `failed`. Fails with an NPE if nothing is pending or `value` is null.
fun resume(value: Type?) {
val temp = continuation!!
continuation = null
temp.resume(value!!)
}
}
// Copies `amount` bytes into `array` at `offset`. Arguments are (using, array, offset, amount, continuation).
// The `offset`/`array` fields remember the destination while a request is suspended.
private val arrayHandler: (ByteBuffer, ByteArray, Int, Int, Continuation<ByteArray>) -> Any = object :
(ByteBuffer, ByteArray, Int, Int, Continuation<ByteArray>) -> Any,
Handler<ByteArray>() {
var offset = 0
var array: ByteArray? = null
override fun invoke(
using: ByteBuffer,
array: ByteArray,
offset: Int,
amount: Int,
continuation: Continuation<ByteArray>
): Any {
if (this.continuation != null)
throw ReadInProgressException()
val remaining = using.remaining()
// Fast path: enough bytes are already staged, so copy them and return the array synchronously.
return if (remaining >= amount) {
using.get(array, offset, amount)
array
} else {
required = amount - remaining
this.continuation = continuation
this.array = array
// NOTE(review): the `remaining` staged bytes are never copied into `array` on this path, and
// `offset - remaining` looks like it should be `offset + remaining`. Confirm intent.
this.offset = offset - remaining
using.flip()
// `this.offset != offset` is equivalent to `remaining != 0`.
// NOTE(review): when nothing was staged, this calls completed(0, using) instead of starting a
// read. completed does not start a read on every path, so the request may never resume.
if (this.offset != offset)
read(using, this)
else
completed(remaining, using)
COROUTINE_SUSPENDED
}
}
override fun completed(count: Int, buffer: ByteBuffer) {
required -= count
if (required < 1) {
buffer.flip()
// NOTE(review): `required` is <= 0 here, so this copies nothing (length 0) or throws
// IndexOutOfBoundsException (negative length). The final bytes of the request were
// presumably meant to be copied instead.
buffer.get(array, offset, required).clear()
resume(array)
} else {
// NOTE(review): the first branch copies `position()` bytes but starts no further read, so the
// request can stall. The second branch reads again without copying anything.
if (buffer.remaining() < required) {
val remaining = buffer.position()
buffer.flip()
buffer.get(array, offset, remaining).clear()
offset += remaining
} else
read(buffer, this)
}
}
}
// Fills `destination` until it has no space remaining. It first drains whatever is staged in
// `using`, then reads directly into `destination`.
private val bufferHandler: (ByteBuffer, ByteBuffer, Continuation<ByteBuffer>) -> Any = object :
(ByteBuffer, ByteBuffer, Continuation<ByteBuffer>) -> Any,
Handler<ByteBuffer>() {
override fun invoke(
using: ByteBuffer,
destination: ByteBuffer,
continuation: Continuation<ByteBuffer>
): Any {
//TODO we could use required instead, but maybe nulling out is good?
if (this.continuation != null)
throw ReadInProgressException()
// NOTE(review): ByteBuffer.put(ByteBuffer) throws BufferOverflowException if `using` holds more
// bytes than `destination` has room for.
if (using.hasRemaining())
destination.put(using)
required = destination.remaining()
return if (required < 1)
destination
else {
this.continuation = continuation
read(destination, this)
COROUTINE_SUSPENDED
}
}
// Keeps reading into `destination` until the originally missing byte count has arrived.
// NOTE(review): a count of -1 (end of stream) increases `required`, so reads would keep being issued.
override fun completed(count: Int, destination: ByteBuffer) {
required -= count
if (required < 1) {
resume(destination)
} else {
read(destination, this)
}
}
}
// Decodes an `amount`-byte primitive from `using` with `converter`. When fewer than `amount` bytes
// are staged, it reads more into `using` first.
private val primitiveHandler: (ByteBuffer, Int, (ByteBuffer) -> Number, Continuation<Number>) -> Any = object :
(ByteBuffer, Int, (ByteBuffer) -> Number, Continuation<Number>) -> Any,
Handler<Number>() {
// True when the pending read appends after the staged bytes (their start remembered with mark())
// instead of compacting them to the front of the buffer.
var marked = false
var converter: ((ByteBuffer) -> Number)? = null
override fun invoke(
using: ByteBuffer,
amount: Int,
converter: (ByteBuffer) -> Number,
continuation: Continuation<Number>
): Any {
if (this.continuation != null)
throw ReadInProgressException()
val remaining = using.remaining()
// NOTE(review): unlike the array fast path, this flips `using` before decoding. flip() sets the
// position to 0 and the limit to the old position, so it only decodes the staged bytes if
// `using` was left in write mode. Confirm.
return if (remaining >= amount)
converter(using.flip() as ByteBuffer)
else {
this.continuation = continuation
this.converter = converter
required = amount - remaining
// NOTE(review): when nothing is staged, the buffer is cleared but no read is started
// (`read(using, this)` is only in the else branch), so nothing will resume this request.
if (remaining == 0)
using.clear()
else {
val capacity = using.capacity()
val limit = using.limit()
// Too little room after `limit` for the missing bytes: compact the staged bytes to the
// front. Otherwise, mark where they start and open the space after `limit` for the read.
if (capacity - limit < required) {
using.compact()
} else {
marked = true
using.mark().position(limit).limit(capacity)
}
read(using, this)
}
COROUTINE_SUSPENDED
}
}
// NOTE(review): there is no else branch, so if fewer than `required` bytes arrive, no further read
// is started and the request never resumes. In the marked case the limit stays at capacity after
// reset(), so bytes up to capacity look readable after the decode.
// NOTE(review): `converter = null` runs after resume(). If resume runs the coroutine synchronously
// (the continuation may be unintercepted) and that coroutine starts another primitive read, this
// line would clear the new read's converter.
override fun completed(count: Int, buffer: ByteBuffer) {
required -= count
if (required < 1) {
if (marked) {
marked = false
buffer.reset()
} else
buffer.flip()
resume(converter!!.invoke(buffer))
converter = null
}
}
}
// Reads into `destination` via the buffer handler.
override fun buffer(
using: ByteBuffer,
destination: ByteBuffer,
continuation: Continuation<ByteBuffer>
) = bufferHandler(using, destination, continuation)
// Copies into `array` via the array handler. Note the argument order: offset comes before amount.
override fun array(
using: ByteBuffer,
array: ByteArray,
offset: Int,
amount: Int,
continuation: Continuation<ByteArray>
) = arrayHandler(using, array, offset, amount, continuation)
// Unchecked cast: the handler is typed on Number, but it only returns, or resumes with, the result
// of `converter`, which is a `Type`.
override fun <Type : Number> number(
using: ByteBuffer,
amount: Int,
converter: (ByteBuffer) -> Type,
continuation: Continuation<Type>
) = primitiveHandler(using, amount, converter, continuation as Continuation<Number>)
}
package com.mynt.network.providers
import com.mynt.network.Connection
import com.mynt.network.Provider
import com.mynt.network.implementation.continued
import com.mynt.network.implementation.sequential.SequentialRead
import com.mynt.network.implementation.sequential.SequentialReadCoordinator
import com.mynt.network.implementation.sequential.SequentialWrite
import com.mynt.network.implementation.sequential.SequentialWriteCoordinator
import kotlinx.coroutines.experimental.async
import java.lang.Boolean.TRUE
import java.net.SocketAddress
import java.nio.ByteBuffer
import java.nio.channels.AsynchronousChannelGroup
import java.nio.channels.AsynchronousServerSocketChannel
import java.nio.channels.AsynchronousSocketChannel
import java.nio.channels.CompletionHandler
import java.util.*
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeUnit.MILLISECONDS
import kotlin.coroutines.experimental.Continuation
import kotlin.coroutines.experimental.intrinsics.COROUTINE_SUSPENDED
/**
 * [Provider] that opens NIO.2 asynchronous TCP channels inside [group].
 *
 * Every connection returned by [accept] or [connect] gets its own read buffer and write buffer from
 * [allocator] (one call each) and performs I/O through [SequentialRead] / [SequentialWrite].
 *
 * @param group channel group owning every server and client channel opened here.
 * @param allocator supplies a fresh staging buffer; called twice per connection.
 */
class TCPSocketProvider(
    private val group: AsynchronousChannelGroup,
    private val allocator: () -> ByteBuffer
) : Provider {
    // One listening channel per bound address, created on first accept and reused afterwards.
    private val servers = HashMap<SocketAddress, AsynchronousServerSocketChannel>()

    //TODO is this really a win?
    private val serverFactory = { address: SocketAddress ->
        group.provider().openAsynchronousServerSocketChannel(group).bind(address)
    }

    /** [Connection] over one [AsynchronousSocketChannel]; [channel] must be assigned before any I/O. */
    private open class Handler(
        allocator: () -> ByteBuffer
    ) : Connection {
        open lateinit var channel: AsynchronousSocketChannel

        // The lambdas read `channel` when they run, not at construction, so it may be assigned later.
        override val read = SequentialRead(SequentialReadCoordinator { buffer, handler ->
            channel.read(buffer, buffer, handler)
        }, allocator())
        override val write = SequentialWrite(SequentialWriteCoordinator { buffer, handler ->
            channel.write(buffer, buffer, handler)
        }, allocator())
        override val isOpen
            get() = channel.isOpen

        override fun close() = channel.close()
    }

    //--Accept--
    /** Completes an accept by adopting the accepted channel and resuming with itself as the Connection. */
    private class AcceptHandler(
        allocator: () -> ByteBuffer
    ) : Handler(allocator), CompletionHandler<AsynchronousSocketChannel, Continuation<Connection>> {
        override fun completed(channel: AsynchronousSocketChannel, continuation: Continuation<Connection>) {
            this.channel = channel
            continuation.resume(this)
        }

        override fun failed(reason: Throwable, continuation: Continuation<Connection>) =
            continuation.resumeWithException(reason)
    }

    /** Suspends until one inbound connection arrives on [address], binding a listener on first use. */
    override suspend fun accept(address: SocketAddress) = continued<Connection> { continuation ->
        servers.computeIfAbsent(address, serverFactory).accept(continuation, AcceptHandler(allocator))
        // accept(...) returns immediately and AcceptHandler resumes us later. Without this marker, the
        // lambda's Unit value would be taken as the Connection result.
        COROUTINE_SUSPENDED
    }

    //--Connect--
    /** Completes a connect on a channel opened just for this attempt, closing it if the attempt fails. */
    private class ConnectHandler(
        allocator: () -> ByteBuffer,
        override var channel: AsynchronousSocketChannel
    ) : Handler(allocator), CompletionHandler<Void?, Continuation<Connection>> {
        override fun completed(ignored: Void?, continuation: Continuation<Connection>) =
            continuation.resume(this)

        override fun failed(reason: Throwable, continuation: Continuation<Connection>) {
            // Nobody else holds this channel, so release it instead of leaking it.
            try {
                channel.close()
            } catch (ignored: Exception) {
                // Best effort: the connect failure below is the error the caller needs to see.
            }
            continuation.resumeWithException(reason)
        }
    }

    /** Opens a channel in [group] and suspends until it is connected to [address]. */
    override suspend fun connect(address: SocketAddress) = continued<Connection> { continuation ->
        val channel = group.provider().openAsynchronousSocketChannel(group)
        try {
            channel.connect(address, continuation, ConnectHandler(allocator, channel))
        } catch (reason: Exception) {
            // connect(...) can reject the address synchronously (e.g. unresolved or unsupported), which the
            // handler never sees, so close the channel here before propagating.
            channel.close()
            throw reason
        }
        // The result arrives later through ConnectHandler; report suspension instead of returning Unit.
        COROUTINE_SUSPENDED
    }

    //--State--
    override val isOpen
        get() = !group.isTerminated

    override fun close() = group.shutdown()

    //TODO Maybe just make close do this?
    /**
     * Shuts the group down and suspends until it terminates or [period] [units] elapse.
     *
     * @return `true` if the group terminated, or `false` if the timeout elapsed first.
     * NOTE(review): the wait blocks whichever thread `async` dispatches it to, for its whole duration.
     * NOTE(review): a group only terminates once its channels are closed, and the listening channels
     * in `servers` are never closed here.
     */
    suspend fun awaitClose(
        period: Long = Long.MAX_VALUE,
        units: TimeUnit = MILLISECONDS
    ) = continued<Boolean> { continuation ->
        close()
        async(continuation.context) {
            val terminated = try {
                group.awaitTermination(period, units)
            } catch (reason: Exception) {
                continuation.resumeWithException(reason)
                return@async
            }
            // Resume outside the try, so a failure in the resumed code is not reported back through the
            // same continuation a second time.
            continuation.resume(terminated)
        }
        // The async block resumes us; returning its Deferred would be taken as the Boolean result.
        COROUTINE_SUSPENDED
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment