Skip to content

Instantly share code, notes, and snippets.

@I-Info
Last active Feb 17, 2022
Embed
What would you like to do?
A protocol definition and implementation using Java NIO SocketChannel
package com.i1nfo.pbtest.server
import com.i1nfo.pbtest.Protocol
import java.io.IOException
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.SelectionKey
import java.nio.channels.Selector
import java.nio.channels.ServerSocketChannel
import java.nio.channels.SocketChannel
import java.util.*
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
/**
 * Per-connection codec for the START / VERSION / TYPE / payload / END framing of `Protocol`.
 *
 * Incoming bytes are decoded incrementally by a small state machine. A frame may therefore be
 * split across several reads, and several frames may arrive in a single read. Each complete frame
 * whose version equals `Protocol.VERSION` is passed to [readHandler]. Frames carrying another
 * version are consumed and dropped.
 *
 * Outgoing frames are written directly when possible. Whatever the socket cannot accept
 * immediately is queued and flushed by [handleWriteEvent] once the selector reports `OP_WRITE`.
 * The write path (queue and interest ops) is guarded by [lock].
 *
 * @param socketChannel the connected channel. NOTE(review): assumed to already be in
 *   non-blocking mode (the server's accept handler configures it); not checked here.
 * @param key the selection key [socketChannel] is registered under.
 * @param readHandler invoked with the type byte and unescaped payload of every valid frame.
 */
class ProtocolHandler(
    private val socketChannel: SocketChannel,
    private val key: SelectionKey,
    private val readHandler: ProtocolHandler.(type: Byte, data: ByteArray) -> Unit
) {
    /** Decoder state: where the next incoming byte belongs within a frame. */
    enum class Status {
        WAITING, // scanning for Protocol.START
        STARTED, // START seen; expecting the VERSION and TYPE header bytes
        READING, // inside the payload
        ESCAPE // previous payload byte was Protocol.ESCAPE; the next byte is taken literally
    }

    private val readBuffer = ByteBuffer.allocateDirect(BUFFER_SIZE_READ)

    // Frames (or frame remainders) the socket could not accept yet; guarded by [lock].
    private val writeQueue = LinkedList<ByteBuffer>()

    // Whether the frame currently being decoded carries the supported protocol version.
    private var valid = true
    private var status = Status.WAITING

    // Type byte of the frame currently being decoded.
    private var type: Byte = 0

    // Unescaped payload accumulated for the current frame.
    private val byteRead = arrayListOf<Byte>()
    private val lock = ReentrantLock()

    /**
     * Reads whatever is available from the socket and decodes every complete frame in it.
     *
     * Closes the connection on end-of-stream or on an [IOException]. This includes an
     * [IOException] raised by a [write] performed from inside [readHandler].
     */
    fun handleReadEvent() {
        try {
            val len = socketChannel.read(readBuffer)
            if (len == -1) {
                // The peer closed its side of the connection.
                closeConnection()
            } else if (len > 0) {
                readBuffer.flip() // switch the buffer from filling to draining
                decodeBuffered()
                // Keep the unconsumed bytes (at most a partial header) for the next read.
                readBuffer.compact()
            }
        } catch (e: IOException) {
            closeConnection()
        }
    }

    /** Runs the frame state machine over all bytes currently readable from [readBuffer]. */
    private fun decodeBuffered() {
        while (readBuffer.hasRemaining()) {
            when (status) {
                Status.WAITING -> {
                    if (readBuffer.get() == Protocol.START) {
                        status = Status.STARTED
                        byteRead.clear()
                    }
                }
                Status.STARTED -> {
                    if (readBuffer.remaining() < Protocol.HEADER_LENGTH) {
                        // The header is incomplete; leave it in the buffer until more data arrives.
                        return
                    }
                    valid = readBuffer.get() == Protocol.VERSION
                    type = readBuffer.get()
                    status = Status.READING
                }
                Status.READING -> {
                    val byte = readBuffer.get()
                    when (byte) {
                        Protocol.END -> {
                            // Frames with an unsupported version are dropped here.
                            if (valid) {
                                readHandler(type, byteRead.toByteArray())
                            }
                            status = Status.WAITING
                            // Keep decoding: more frames may follow in the same buffer.
                        }
                        Protocol.ESCAPE -> status = Status.ESCAPE
                        else -> if (valid) byteRead.add(byte)
                    }
                }
                Status.ESCAPE -> {
                    // Always consume the escaped byte, even for dropped frames, so the loop advances.
                    val byte = readBuffer.get()
                    if (valid) {
                        byteRead.add(byte)
                    }
                    status = Status.READING
                }
            }
        }
    }

    /**
     * Flushes queued frames while the socket accepts data.
     *
     * Drops `OP_WRITE` interest once the queue is empty and closes the connection on [IOException].
     */
    fun handleWriteEvent() {
        try {
            // Sync with write(), which may be called concurrently.
            lock.withLock {
                while (true) {
                    val head: ByteBuffer? = writeQueue.peek()
                    if (head == null) {
                        // Nothing left to flush; stop asking the selector for write readiness.
                        key.interestOpsAnd(SelectionKey.OP_WRITE.inv())
                        break
                    }
                    socketChannel.write(head)
                    if (head.hasRemaining()) {
                        // Socket send buffer is full; wait for the next OP_WRITE.
                        break
                    }
                    writeQueue.poll()
                }
            }
        } catch (e: IOException) {
            closeConnection()
        }
    }

    /**
     * Builds one frame: START, VERSION, [type], the escaped [data], END.
     *
     * This is the same layout that `Protocol.pack` produces.
     */
    private fun packData(type: Byte, data: ByteArray): ByteArray {
        val frame = arrayListOf(Protocol.START, Protocol.VERSION, type)
        for (b in data) {
            if (b == Protocol.START || b == Protocol.END || b == Protocol.ESCAPE) {
                frame.add(Protocol.ESCAPE)
            }
            frame.add(b)
        }
        frame.add(Protocol.END)
        return frame.toByteArray()
    }

    /**
     * Frames [data] with [type] and sends it, queueing whatever the socket cannot take right now.
     *
     * Empty payloads are silently ignored. The body runs under [lock] ("write at any
     * circumstance"). An [IOException] from the socket propagates to the caller.
     */
    fun write(type: Byte, data: ByteArray) {
        if (data.isEmpty()) {
            return
        }
        val packedData = ByteBuffer.wrap(packData(type, data))
        lock.withLock {
            // Write directly only when nothing is queued; otherwise frames would be reordered.
            if (writeQueue.isEmpty()) {
                socketChannel.write(packedData)
            }
            if (packedData.hasRemaining()) {
                writeQueue.offer(packedData)
                key.interestOpsOr(SelectionKey.OP_WRITE)
                key.selector().wakeup()
            }
        }
    }

    /** Closes the channel (best effort) and always cancels its selection key. */
    private fun closeConnection() {
        try {
            socketChannel.close()
        } catch (e: IOException) {
            // Best effort: the connection is being torn down anyway.
        } finally {
            key.cancel()
        }
    }

    companion object {
        private const val BUFFER_SIZE_READ = 4 * 1024
    }
}
/**
 * Echo server: listens on 127.0.0.1:9000 and sends every valid frame it receives back to the
 * sender with the same type. It runs a single-threaded selector loop forever.
 */
fun main() {
    val selector = Selector.open()
    val serverSocketChannel = ServerSocketChannel.open()
    serverSocketChannel.bind(InetSocketAddress("127.0.0.1", 9000))
    serverSocketChannel.configureBlocking(false)
    println("Start listening at ${serverSocketChannel.localAddress} ...")
    serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT, Runnable {
        // In non-blocking mode accept() returns null when no connection is actually pending.
        val socketChannel = serverSocketChannel.accept() ?: return@Runnable
        socketChannel.configureBlocking(false)
        println("Client from ${socketChannel.remoteAddress} connected.")
        val key = socketChannel.register(selector, SelectionKey.OP_READ)
        // Echo every received frame back with the same type.
        val protocolHandler = ProtocolHandler(socketChannel, key) { type, data ->
            write(type, data)
        }
        key.attach(protocolHandler)
    })
    while (true) {
        selector.select()
        val iterator = selector.selectedKeys().iterator()
        while (iterator.hasNext()) {
            val key = iterator.next()
            // Remove first so the key is never reprocessed, even if a handler throws.
            iterator.remove()
            // A cancelled key would make the readiness queries below throw CancelledKeyException.
            if (!key.isValid) {
                continue
            }
            if (key.isAcceptable) {
                (key.attachment() as Runnable?)?.run()
                continue
            }
            val handler = key.attachment() as ProtocolHandler
            if (key.isReadable) {
                handler.handleReadEvent()
            }
            // The read handler may have closed the connection, so re-check validity.
            if (key.isValid && key.isWritable) {
                handler.handleWriteEvent()
            }
        }
    }
}
package com.i1nfo.pbtest
import java.io.IOException
import java.io.InputStream
/**
 * Byte-level framing shared by the client and the server.
 *
 * Frame layout:
 * ```
 * 0x40    0xNN     0xNN       ...       0x04
 * START   VERSION  data type  payload   END
 * ```
 * Inside the payload, any byte equal to START, END or ESCAPE (0x1B) is preceded by an ESCAPE byte.
 */
object Protocol {
    const val VERSION: Byte = 0x01
    const val START: Byte = 0x40
    const val END: Byte = 0x04
    const val ESCAPE: Byte = 0x1B

    // Number of header bytes after START: VERSION and data type.
    const val HEADER_LENGTH = 2

    private fun getHeader(type: Byte) =
        arrayListOf(START, VERSION, type)

    /**
     * Encodes one frame carrying [type] and [data], escaping reserved bytes in the payload.
     *
     * @throws IOException if [data] is empty.
     */
    fun pack(type: Byte, data: ByteArray): ByteArray {
        if (data.isEmpty()) {
            throw IOException("cannot pack an empty payload")
        }
        val temp = getHeader(type)
        for (b in data) {
            if (b == START || b == END || b == ESCAPE) {
                temp.add(ESCAPE)
            }
            temp.add(b)
        }
        temp.add(END)
        return temp.toByteArray()
    }

    /**
     * Blocks until one frame with the supported [VERSION] has been read from [inputStream].
     *
     * Bytes before a START byte, and whole frames carrying another version, are skipped.
     *
     * @return the frame's type byte and its unescaped payload.
     * @throws IOException on end-of-stream (including in the middle of a frame) or any read error.
     */
    fun unpack(inputStream: InputStream): Pair<Byte, ByteArray> {
        val start = START.toInt()
        val esc = ESCAPE.toInt()
        val end = END.toInt()
        while (true) {
            if (readByte(inputStream) != start) {
                continue
            }
            val version = readByte(inputStream)
            if (version.toByte() != VERSION) {
                // Skip the rest of this frame, honouring escapes so an escaped END does not end it.
                do {
                    val current = readByte(inputStream)
                    if (current == esc) {
                        readByte(inputStream)
                    }
                } while (current != end)
                continue
            }
            val type = readByte(inputStream).toByte()
            val data = ArrayList<Byte>()
            while (true) {
                when (val current = readByte(inputStream)) {
                    esc -> data.add(readByte(inputStream).toByte())
                    end -> return Pair(type, data.toByteArray())
                    else -> data.add(current.toByte())
                }
            }
        }
    }

    /** Reads one byte (0..255), throwing instead of returning -1 at end-of-stream. */
    private fun readByte(inputStream: InputStream): Int {
        val b = inputStream.read()
        if (b == -1) {
            throw IOException("EOF")
        }
        return b
    }
}
package com.i1nfo.pbtest.client
import com.i1nfo.pbtest.BasicMsgOuterClass
import com.i1nfo.pbtest.Protocol
import com.i1nfo.pbtest.basicMsg
import java.io.IOException
import java.net.Socket
import kotlin.system.exitProcess
/**
 * Interactive client: connects to 127.0.0.1:9000 and runs two loops.
 *
 * - Each line read from stdin is sent as a type-1 frame holding a `BasicMsg` with `code = 1`.
 * - A background thread prints every type-1 frame received from the server.
 *
 * Sending stops at stdin end-of-file; the process exits when the server connection is lost.
 */
fun main() {
    println("connecting to server...")
    val socket: Socket
    try {
        socket = Socket("127.0.0.1", 9000)
    } catch (e: Exception) {
        println("connect failed")
        return
    }
    println("connected.")
    val outputStream = socket.getOutputStream()
    val inputStream = socket.getInputStream()
    Thread {
        while (true) {
            val frame = try {
                Protocol.unpack(inputStream)
            } catch (e: IOException) {
                println("disconnected from server.")
                exitProcess(0)
            }
            if (frame.first == 1.toByte()) {
                // Parse failures are IOException subclasses too, but they mean a bad frame,
                // not a lost connection, so they must not terminate the client.
                try {
                    println(BasicMsgOuterClass.BasicMsg.parseFrom(frame.second))
                } catch (e: IOException) {
                    println("received malformed message: ${e.message}")
                }
            }
        }
    }.start()
    try {
        while (true) {
            // null means stdin reached end-of-file: stop sending and close the connection.
            val line = readlnOrNull() ?: break
            val data = Protocol.pack(1, basicMsg { code = 1; msg = line }.toByteArray())
            outputStream.write(data)
            outputStream.flush()
        }
    } catch (e: IOException) {
        println("send failed: ${e.message}")
    } finally {
        // Closing the socket also unblocks the reader thread, which then exits the process.
        socket.close()
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment