A protocol defined and implemented with Java NIO SocketChannel
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.i1nfo.pbtest.server | |
import com.i1nfo.pbtest.Protocol | |
import java.io.IOException | |
import java.net.InetSocketAddress | |
import java.nio.ByteBuffer | |
import java.nio.channels.SelectionKey | |
import java.nio.channels.Selector | |
import java.nio.channels.ServerSocketChannel | |
import java.nio.channels.SocketChannel | |
import java.util.* | |
import java.util.concurrent.locks.ReentrantLock | |
import kotlin.concurrent.withLock | |
/**
 * Per-connection codec for a non-blocking [SocketChannel]: decodes incoming [Protocol] frames
 * and sends outgoing frames, queueing them when the socket cannot take them immediately.
 *
 * Frame layout: START, VERSION, TYPE, payload (START/END/ESCAPE bytes prefixed by ESCAPE), END.
 *
 * @param socketChannel non-blocking channel for one client.
 * @param key selection key of [socketChannel]; OP_WRITE is added/removed on it as the write
 *   queue fills and drains.
 * @param readHandler called with the frame type and unescaped payload for every complete frame
 *   whose version byte equals [Protocol.VERSION]; frames with another version are dropped.
 */
class ProtocolHandler(
    private val socketChannel: SocketChannel,
    private val key: SelectionKey,
    private val readHandler: ProtocolHandler.(type: Byte, data: ByteArray) -> Unit
) {
    /** Decoder state. */
    enum class Status {
        /** Skipping bytes until a START marker. */
        WAITING,

        /** START seen; waiting for the version and type header bytes. */
        STARTED,

        /** Collecting payload bytes until END. */
        READING,

        /** Previous byte was ESCAPE; the next byte is taken literally. */
        ESCAPE
    }

    private val readBuffer = ByteBuffer.allocateDirect(BUFFER_SIZE_READ)

    // Packets waiting for OP_WRITE; every access is guarded by [lock].
    private val writeQueue = LinkedList<ByteBuffer>()
    private val lock = ReentrantLock()

    // Decoder state. Not synchronized: handleReadEvent must only be called from one thread.
    private var valid = true
    private var status = Status.WAITING
    private var type: Byte = 0
    private val byteRead = arrayListOf<Byte>()

    /**
     * Reads whatever is available from the channel and decodes it. Every complete frame in the
     * buffer is delivered, not just the first one.
     * Closes the channel and cancels the key on EOF or on an [IOException]. That includes
     * IOExceptions thrown by a [write] issued from inside [readHandler].
     */
    fun handleReadEvent() {
        try {
            val len = socketChannel.read(readBuffer)
            if (len == -1) {
                // peer closed the connection
                closeConnection()
                return
            }
            if (len > 0) {
                readBuffer.flip()
                decodeBuffered()
                // keep unconsumed bytes (at most a split header) for the next read
                readBuffer.compact()
            }
        } catch (e: IOException) {
            closeConnection()
        }
    }

    /** Runs the frame state machine over all bytes between the buffer's position and limit. */
    private fun decodeBuffered() {
        while (readBuffer.hasRemaining()) {
            when (status) {
                Status.WAITING -> {
                    if (readBuffer.get() == Protocol.START) {
                        status = Status.STARTED
                        byteRead.clear()
                    }
                }
                Status.STARTED -> {
                    if (readBuffer.remaining() < Protocol.HEADER_LENGTH) {
                        // header split across reads: leave the bytes for compact() and wait
                        return
                    }
                    // read order matters: version byte first, then the type byte
                    valid = readBuffer.get() == Protocol.VERSION
                    type = readBuffer.get()
                    status = Status.READING
                }
                Status.READING -> {
                    val byte = readBuffer.get()
                    if (byte == Protocol.END) {
                        // no break here: further frames may follow in the same buffer
                        finishFrame()
                    } else if (byte == Protocol.ESCAPE) {
                        status = Status.ESCAPE
                    } else if (valid) {
                        byteRead.add(byte)
                    }
                }
                Status.ESCAPE -> {
                    // always consume the escaped byte (even for dropped frames), then resume
                    val byte = readBuffer.get()
                    if (valid) {
                        byteRead.add(byte)
                    }
                    status = Status.READING
                }
            }
        }
    }

    /** Resets the decoder for the next frame, then delivers the finished one if its version matched. */
    private fun finishFrame() {
        val payload = byteRead.toByteArray()
        byteRead.clear()
        status = Status.WAITING
        if (valid) {
            readHandler(type, payload)
        }
    }

    /**
     * Flushes queued packets while the socket accepts data.
     * Clears OP_WRITE once the queue is empty.
     * Closes the connection on [IOException].
     */
    fun handleWriteEvent() {
        try {
            lock.withLock {
                while (true) {
                    val pending = writeQueue.peek()
                    if (pending == null) {
                        // nothing left to send: stop asking for OP_WRITE
                        key.interestOpsAnd(SelectionKey.OP_WRITE.inv())
                        return
                    }
                    socketChannel.write(pending)
                    if (pending.hasRemaining()) {
                        // socket send buffer is full; wait for the next OP_WRITE
                        return
                    }
                    writeQueue.poll()
                }
            }
        } catch (e: IOException) {
            closeConnection()
        }
    }

    /**
     * Packs and sends one frame; may be called from any thread.
     * Empty [data] is ignored.
     * If earlier packets are still queued, or the socket stops accepting data mid-write, the
     * remaining bytes are queued, OP_WRITE is requested and the selector is woken.
     * IOExceptions from the channel propagate to the caller.
     */
    fun write(type: Byte, data: ByteArray) {
        if (data.isEmpty()) {
            return
        }
        // data is non-empty here, so Protocol.pack will not throw for an empty payload
        val packet = ByteBuffer.wrap(Protocol.pack(type, data))
        lock.withLock {
            if (writeQueue.isEmpty()) {
                // nothing queued ahead of this packet, so try sending it directly
                while (packet.hasRemaining()) {
                    if (socketChannel.write(packet) <= 0) {
                        break
                    }
                }
                if (!packet.hasRemaining()) {
                    return
                }
            }
            // preserve ordering: queue behind older packets (or queue our own leftover bytes)
            writeQueue.offer(packet)
            key.interestOpsOr(SelectionKey.OP_WRITE)
            key.selector().wakeup()
        }
    }

    /** Closes the channel and cancels its key; a failure while closing is ignored. */
    private fun closeConnection() {
        try {
            socketChannel.close()
        } catch (e: IOException) {
            // best effort: the connection is already being torn down
        }
        key.cancel()
    }

    companion object {
        private const val BUFFER_SIZE_READ = 4 * 1024
    }
}
/**
 * Single-threaded echo server on 127.0.0.1:9000.
 *
 * Each accepted client gets a [ProtocolHandler], and every decoded packet is written straight
 * back to the client that sent it.
 */
fun main() {
    val selector = Selector.open()
    val serverSocketChannel = ServerSocketChannel.open()
    serverSocketChannel.bind(InetSocketAddress("127.0.0.1", 9000))
    serverSocketChannel.configureBlocking(false)
    println("Start listening at ${serverSocketChannel.localAddress} ...")
    // The accept key carries a Runnable; each client key carries its ProtocolHandler.
    serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT, Runnable {
        // a non-blocking accept() returns null when no connection is actually pending
        val socketChannel = serverSocketChannel.accept() ?: return@Runnable
        socketChannel.configureBlocking(false)
        println("Client from ${socketChannel.remoteAddress} connected.")
        val key = socketChannel.register(selector, SelectionKey.OP_READ)
        val protocolHandler = ProtocolHandler(socketChannel, key) { type, data ->
            write(type, data)
        }
        key.attach(protocolHandler)
    })
    while (true) {
        selector.select()
        val iterator = selector.selectedKeys().iterator()
        while (iterator.hasNext()) {
            val key = iterator.next()
            // remove first so a throwing handler cannot leave the key in the selected set
            iterator.remove()
            if (!key.isValid) {
                continue
            }
            if (key.isAcceptable) {
                (key.attachment() as Runnable?)?.run()
                continue
            }
            val handler = key.attachment() as ProtocolHandler
            if (key.isReadable) {
                handler.handleReadEvent()
            }
            // the read handler may have closed the connection and cancelled the key
            if (key.isValid && key.isWritable) {
                handler.handleWriteEvent()
            }
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.i1nfo.pbtest | |
import java.io.IOException | |
import java.io.InputStream | |
object Protocol {
    /** Version byte written into, and expected in, every frame header. */
    const val VERSION: Byte = 0x01

    /** Marks the beginning of a frame. */
    const val START: Byte = 0x40

    /** Marks the end of a frame. */
    const val END: Byte = 0x04

    /** Placed before a payload byte that equals START, END or ESCAPE. */
    const val ESCAPE: Byte = 0x1B

    /** Number of header bytes following START: version and data type. */
    const val HEADER_LENGTH = 2

    /*
     * Protocol definition
     *   0x40   0xNN     0xNN       ...    0x1b     0x04
     *   start  version  data type  data   escape   end
     */

    private fun getHeader(type: Byte) =
        arrayListOf(START, VERSION, type)

    /**
     * Encodes [data] as one frame of the given [type], escaping START, END and ESCAPE bytes.
     *
     * @throws IOException if [data] is empty.
     */
    fun pack(type: Byte, data: ByteArray): ByteArray {
        if (data.isEmpty()) {
            throw IOException("cannot pack an empty payload")
        }
        val frame = getHeader(type)
        for (b in data) {
            if (b == START || b == END || b == ESCAPE) {
                frame.add(ESCAPE)
            }
            frame.add(b)
        }
        frame.add(END)
        return frame.toByteArray()
    }

    /**
     * Blocks until one frame with a matching [VERSION] has been read from [inputStream].
     *
     * Bytes before a START marker are skipped. Frames with another version are skipped up to
     * their END.
     *
     * @return the frame's type byte and unescaped payload.
     * @throws IOException on end of stream, including EOF in the middle of a frame.
     */
    fun unpack(inputStream: InputStream): Pair<Byte, ByteArray> {
        val start = START.toInt()
        val esc = ESCAPE.toInt()
        val end = END.toInt()
        while (true) {
            if (readOrThrow(inputStream) != start) {
                continue
            }
            val version = readOrThrow(inputStream)
            if (version.toByte() != VERSION) {
                skipFrame(inputStream)
                continue
            }
            val type = readOrThrow(inputStream).toByte()
            val data = ArrayList<Byte>()
            while (true) {
                val current = readOrThrow(inputStream)
                if (current == esc) {
                    data.add(readOrThrow(inputStream).toByte())
                } else if (current == end) {
                    return Pair(type, data.toByteArray())
                } else {
                    data.add(current.toByte())
                }
            }
        }
    }

    /** Consumes bytes up to and including the next unescaped END. */
    private fun skipFrame(input: InputStream) {
        while (true) {
            val current = readOrThrow(input)
            if (current == ESCAPE.toInt()) {
                readOrThrow(input)
            } else if (current == END.toInt()) {
                return
            }
        }
    }

    /** Reads one byte (0..255), throwing instead of returning -1 at end of stream. */
    private fun readOrThrow(input: InputStream): Int {
        val value = input.read()
        if (value == -1) {
            throw IOException("EOF")
        }
        return value
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.i1nfo.pbtest.client | |
import com.i1nfo.pbtest.BasicMsgOuterClass | |
import com.i1nfo.pbtest.Protocol | |
import com.i1nfo.pbtest.basicMsg | |
import java.io.IOException | |
import java.net.Socket | |
import kotlin.system.exitProcess | |
/**
 * Interactive client for the server at 127.0.0.1:9000.
 *
 * Each line typed on standard input is sent as a type-1 frame carrying a `BasicMsg` with
 * `code = 1`. Every type-1 frame received back is parsed as a `BasicMsg` and printed.
 * A background thread ends the process once the connection is lost. At stdin EOF the socket is
 * closed, which also makes that thread exit.
 */
fun main() {
    println("connecting to server...")
    val socket = try {
        Socket("127.0.0.1", 9000)
    } catch (e: Exception) {
        println("connect failed")
        return
    }
    println("connected.")
    socket.use {
        val outputStream = socket.getOutputStream()
        val inputStream = socket.getInputStream()
        // Receiver thread: exits the whole process when reading from the socket fails.
        Thread {
            while (true) {
                try {
                    val msg = Protocol.unpack(inputStream)
                    if (msg.first == 1.toByte()) {
                        println(BasicMsgOuterClass.BasicMsg.parseFrom(msg.second))
                    }
                } catch (e: IOException) {
                    println("disconnected from server.")
                    exitProcess(0)
                }
            }
        }.start()
        while (true) {
            // readlnOrNull() returns null at EOF instead of throwing like readln()
            val line = readlnOrNull() ?: break
            val data = Protocol.pack(1, basicMsg { code = 1; msg = line }.toByteArray())
            try {
                outputStream.write(data)
                outputStream.flush()
            } catch (e: IOException) {
                println("disconnected from server.")
                exitProcess(0)
            }
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment