Skip to content

Instantly share code, notes, and snippets.

@Exerosis
Created June 3, 2018 15:09
Show Gist options
  • Save Exerosis/b767c51d7475a7a37f9aaa723fa48b33 to your computer and use it in GitHub Desktop.
Save Exerosis/b767c51d7475a7a37f9aaa723fa48b33 to your computer and use it in GitHub Desktop.
package com.mynt.network.implementation.sequential
import com.mynt.network.ReadCoordinator
import com.mynt.network.implementation.Holder
import com.mynt.network.implementation.InProgressException
import java.nio.ByteBuffer
import java.nio.channels.CompletionHandler
import kotlin.coroutines.experimental.Continuation
import kotlin.coroutines.experimental.intrinsics.COROUTINE_SUSPENDED
/**
 * [ReadCoordinator] that serves reads out of a caller-supplied staging buffer (`using`) and, when that
 * buffer does not yet hold enough bytes, issues asynchronous reads through [read] until it does.
 *
 * The staging buffer is treated as being in "read mode": the bytes between its position and limit are
 * data that has been received but not yet consumed. Every entry point returns either the finished value
 * (when no I/O was needed) or [COROUTINE_SUSPENDED], in which case the continuation is resumed from the
 * completion handler. A negative byte count from [read] (end of stream) resumes it with a
 * [java.io.EOFException]. A request is rejected with [InProgressException] when the handler's
 * `hold()` refuses the continuation (presumably because an earlier request is still pending).
 *
 * @param read starts one asynchronous read into the given buffer that reports to the given handler,
 *   passing that same buffer as the attachment (for example `channel.read(buffer, buffer, handler)`).
 */
class SequentialReadCoordinator(
    private val read: (ByteBuffer, CompletionHandler<Int, ByteBuffer>) -> Unit
) : ReadCoordinator {
    /**
     * Shared base of the handlers: holds the suspended continuation while reads are in flight and
     * resumes it with the failure if a read fails.
     */
    private abstract class ReadHandler<Type> : Holder<Continuation<Type>>(), CompletionHandler<Int, ByteBuffer> {
        /** Bytes the current request still needs; each handler states its exact meaning. */
        var required = 0

        override fun failed(reason: Throwable, buffer: ByteBuffer) {
            release().resumeWithException(reason)
        }
    }

    /** Fills `amount` bytes of a [ByteArray], staging network data through the shared buffer. */
    private val arrayHandler = object : ReadHandler<ByteArray>() {
        // While suspended: the index in `array` where the next staged bytes go, and the array itself.
        // Here `required` is the number of bytes still to be copied into `array`.
        var offset = 0
        var array: ByteArray? = null

        operator fun invoke(
            using: ByteBuffer,
            array: ByteArray,
            offset: Int,
            amount: Int,
            continuation: Continuation<ByteArray>
        ): Any {
            if (!hold(continuation))
                throw InProgressException()
            val available = using.remaining()
            if (available >= amount) {
                using.get(array, offset, amount)
                // Completed without suspending: give the hold back so the next request is not rejected.
                release()
                return array
            }
            // Take everything already staged, then refill the now-empty buffer from index 0.
            using.get(array, offset, available)
            using.clear()
            this.offset = offset + available
            this.array = array
            required = amount - available
            read(using, this)
            return COROUTINE_SUSPENDED
        }

        override fun completed(count: Int, buffer: ByteBuffer) {
            if (count < 0) {
                failed(java.io.EOFException("Stream ended before the read completed"), buffer)
                return
            }
            val target = checkNotNull(array)
            // The buffer is filled from index 0, so its position counts every byte staged since then.
            val staged = buffer.position()
            if (staged >= required) {
                buffer.flip()
                buffer.get(target, offset, required)
                // Any surplus stays between position and limit for the next request.
                array = null
                release().resume(target)
                return
            }
            if (buffer.remaining() < required - staged) {
                // Not enough free space for the rest: move what we have into the array and start over.
                buffer.flip()
                buffer.get(target, offset, staged)
                buffer.clear()
                offset += staged
                required -= staged
            }
            read(buffer, this)
        }

        override fun failed(reason: Throwable, buffer: ByteBuffer) {
            array = null
            super.failed(reason, buffer)
        }
    }

    /** Fills the remaining space of a destination [ByteBuffer], reading straight into it when needed. */
    private val bufferHandler = object : ReadHandler<ByteBuffer>() {
        // Here `required` is the free space still left in the destination.
        operator fun invoke(
            using: ByteBuffer,
            destination: ByteBuffer,
            continuation: Continuation<ByteBuffer>
        ): Any {
            if (!hold(continuation))
                throw InProgressException()
            // Move only as many staged bytes as fit; put(ByteBuffer) throws instead of copying partially.
            val staged = minOf(using.remaining(), destination.remaining())
            if (staged > 0) {
                val limit = using.limit()
                using.limit(using.position() + staged)
                destination.put(using)
                using.limit(limit)
            }
            required = destination.remaining()
            if (required < 1) {
                // Completed without suspending: give the hold back so the next request is not rejected.
                release()
                return destination
            }
            read(destination, this)
            return COROUTINE_SUSPENDED
        }

        override fun completed(count: Int, destination: ByteBuffer) {
            if (count < 0) {
                failed(java.io.EOFException("Stream ended before the read completed"), destination)
                return
            }
            required -= count
            if (required < 1)
                release().resume(destination)
            else
                read(destination, this)
        }
    }

    /** Decodes a fixed-width number, topping up the staging buffer first when it holds too few bytes. */
    private val primitiveHandler = object : ReadHandler<Number>() {
        // While suspended: where the unread bytes begin when new data is appended after them
        // (-1 when the buffer was cleared/compacted so they begin at index 0), and the decoder to apply.
        // Here `required` is the number of bytes still to be received.
        var start = -1
        var converter: ((ByteBuffer) -> Number)? = null

        operator fun invoke(
            using: ByteBuffer,
            amount: Int,
            converter: (ByteBuffer) -> Number,
            continuation: Continuation<Number>
        ): Any {
            val remaining = using.remaining()
            if (remaining >= amount)
                return converter(using)
            if (!hold(continuation))
                throw InProgressException()
            this.converter = converter
            required = amount - remaining
            val capacity = using.capacity()
            val limit = using.limit()
            when {
                // Nothing unread: reuse the whole buffer from index 0.
                remaining == 0 -> using.clear()
                // Too little room after the limit: slide the unread bytes down to index 0.
                capacity - limit < required -> using.compact()
                // Enough room: append after the unread bytes and remember where they start.
                else -> {
                    start = using.position()
                    using.position(limit)
                    using.limit(capacity)
                }
            }
            read(using, this)
            return COROUTINE_SUSPENDED
        }

        override fun completed(count: Int, buffer: ByteBuffer) {
            if (count < 0) {
                failed(java.io.EOFException("Stream ended before the read completed"), buffer)
                return
            }
            required -= count
            if (required > 0) {
                // Short read: keep filling until the whole number has arrived.
                read(buffer, this)
                return
            }
            if (start >= 0) {
                buffer.limit(buffer.position())
                buffer.position(start)
                start = -1
            } else
                buffer.flip()
            val decode = checkNotNull(converter)
            converter = null
            release().resume(decode(buffer))
        }

        override fun failed(reason: Throwable, buffer: ByteBuffer) {
            start = -1
            converter = null
            super.failed(reason, buffer)
        }
    }

    override fun buffer(
        using: ByteBuffer,
        destination: ByteBuffer,
        continuation: Continuation<ByteBuffer>
    ) = bufferHandler(using, destination, continuation)

    override fun array(
        using: ByteBuffer,
        array: ByteArray,
        amount: Int,
        offset: Int,
        continuation: Continuation<ByteArray>
    ) = arrayHandler(using, array, offset, amount, continuation)

    // The shared handler is typed on Number; the reader guarantees the value handed back is a `Type`.
    @Suppress("UNCHECKED_CAST")
    override fun <Type : Number> number(
        using: ByteBuffer,
        amount: Int,
        reader: (ByteBuffer) -> Type,
        continuation: Continuation<Type>
    ) = primitiveHandler(using, amount, reader, continuation as Continuation<Number>)
}
package com.mynt.network.providers
import com.mynt.network.Connection
import com.mynt.network.Provider
import com.mynt.network.implementation.continued
import com.mynt.network.implementation.sequential.SequentialRead
import com.mynt.network.implementation.sequential.SequentialReadCoordinator
import com.mynt.network.implementation.sequential.SequentialWrite
import com.mynt.network.implementation.sequential.SequentialWriteCoordinator
import kotlinx.coroutines.experimental.async
import java.lang.Boolean.TRUE
import java.net.SocketAddress
import java.nio.ByteBuffer
import java.nio.channels.AsynchronousChannelGroup
import java.nio.channels.AsynchronousServerSocketChannel
import java.nio.channels.AsynchronousSocketChannel
import java.nio.channels.CompletionHandler
import java.util.*
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeUnit.MILLISECONDS
import kotlin.coroutines.experimental.Continuation
import kotlin.coroutines.experimental.intrinsics.COROUTINE_SUSPENDED
/**
 * [Provider] backed by NIO.2 asynchronous TCP channels, all opened in [group].
 *
 * @param group channel group in which every server and client channel is opened.
 * @param allocator supplies staging buffers; each connection takes two (one for reads, one for writes).
 *   NOTE(review): assumed to return buffers positioned at 0, since they are flipped to start out empty.
 */
class TCPSocketProvider(
    private val group: AsynchronousChannelGroup,
    private val allocator: () -> ByteBuffer
) : Provider {
    /** Server channels opened by [accept], keyed by the address they are bound to; closed by [close]. */
    // NOTE(review): plain HashMap, so accept()/close() are assumed to be called from one thread at a time.
    private val servers = HashMap<SocketAddress, AsynchronousServerSocketChannel>()

    // Opens and binds a server channel the first time an address is accepted on. Kept as a single
    // function value so accept() does not build a new lambda per call.
    //TODO is this really a win?
    private val serverFactory = { address: SocketAddress ->
        group.provider().openAsynchronousServerSocketChannel(group).bind(address)
    }

    /** [Connection] over an [AsynchronousSocketChannel]; [channel] must be assigned before any I/O. */
    private open class Handler(
        allocator: () -> ByteBuffer
    ) : Connection {
        open lateinit var channel: AsynchronousSocketChannel

        // Both staging buffers start flipped, i.e. with nothing between position and limit. The read
        // coordinator treats that span as already-received data, so a fresh unflipped buffer would
        // look full of zeros.
        override val read = SequentialRead(SequentialReadCoordinator { buffer, handler ->
            channel.read(buffer, buffer, handler)
        }, allocator().flip() as ByteBuffer)
        override val write = SequentialWrite(SequentialWriteCoordinator { buffer, handler ->
            channel.write(buffer, buffer, handler)
        }, allocator().flip() as ByteBuffer)

        override val isOpen
            get() = channel.isOpen

        override fun close() = channel.close()
    }

    //--Accept--
    /** Completes an accept: adopts the accepted channel and hands itself out as the [Connection]. */
    private class AcceptHandler(
        allocator: () -> ByteBuffer
    ) : Handler(allocator), CompletionHandler<AsynchronousSocketChannel, Continuation<Connection>> {
        override fun completed(channel: AsynchronousSocketChannel, continuation: Continuation<Connection>) {
            this.channel = channel
            continuation.resume(this)
        }

        override fun failed(reason: Throwable, continuation: Continuation<Connection>) =
            continuation.resumeWithException(reason)
    }

    /** Suspends until a client connects to [address], binding a server channel there on first use. */
    override suspend fun accept(address: SocketAddress) = continued<Connection> {
        servers.computeIfAbsent(address, serverFactory).accept(it, AcceptHandler(allocator))
        COROUTINE_SUSPENDED
    }

    //--Connect--
    /** Completes a connect on [channel]; on failure closes it, since no caller will ever own it. */
    private class ConnectHandler(
        allocator: () -> ByteBuffer,
        override var channel: AsynchronousSocketChannel
    ) : Handler(allocator), CompletionHandler<Void?, Continuation<Connection>> {
        override fun completed(ignored: Void?, continuation: Continuation<Connection>) =
            continuation.resume(this)

        override fun failed(reason: Throwable, continuation: Continuation<Connection>) {
            try {
                channel.close()
            } catch (closeFailure: java.io.IOException) {
                // Best effort: the connect failure is the error worth reporting.
            }
            continuation.resumeWithException(reason)
        }
    }

    /** Opens a channel in [group] and suspends until it is connected to [address]. */
    override suspend fun connect(address: SocketAddress) = continued<Connection> {
        val channel = group.provider().openAsynchronousSocketChannel(group)
        try {
            channel.connect(address, it, ConnectHandler(allocator, channel))
        } catch (reason: Throwable) {
            // connect() can also fail synchronously, before any handler runs; don't leak the channel.
            try {
                channel.close()
            } catch (closeFailure: java.io.IOException) {
                // Best effort: rethrow the original failure below.
            }
            throw reason
        }
        COROUTINE_SUSPENDED
    }

    //--State--
    override val isOpen
        get() = !group.isTerminated

    /**
     * Closes the server channels opened by [accept], then initiates an orderly shutdown of [group].
     * Client connections are left for their owners to close.
     */
    override fun close() {
        // The group only terminates once all of its channels are closed, including our own servers.
        for (server in servers.values) {
            try {
                server.close()
            } catch (closeFailure: java.io.IOException) {
                // Best effort: keep closing the rest and still shut the group down.
            }
        }
        servers.clear()
        group.shutdown()
    }

    //TODO Maybe just make close do this?
    /**
     * Calls [close], then suspends until [group] terminates or [period] [units] elapse.
     *
     * @return `true` if the group terminated, `false` if the timeout elapsed first.
     */
    suspend fun awaitClose(
        period: Long = Long.MAX_VALUE,
        units: TimeUnit = MILLISECONDS
    ) = continued<Boolean> { continuation ->
        close()
        // NOTE(review): awaitTermination blocks, so this occupies a thread of the caller's context until it returns.
        async(continuation.context) {
            val terminated = try {
                group.awaitTermination(period, units)
            } catch (reason: Exception) {
                continuation.resumeWithException(reason)
                return@async
            }
            // Resume outside the try so a failing resume can't trigger a second resume.
            continuation.resume(terminated)
        }
        // The result is always delivered through the continuation, as in accept()/connect().
        COROUTINE_SUSPENDED
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment