Skip to content

Instantly share code, notes, and snippets.

@kevinlynx
Created October 20, 2013 13:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kevinlynx/7069826 to your computer and use it in GitHub Desktop.
Save kevinlynx/7069826 to your computer and use it in GitHub Desktop.
Request a torrent's metainfo from a peer, built on Netty and Scala.
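/*
 * Fetches a torrent's metainfo from a single peer using the BitTorrent
 * extension protocol and its ut_metadata extension:
 * http://www.bittorrent.org/beps/bep_0010.html (BEP 10: Extension Protocol)
 * http://www.bittorrent.org/beps/bep_0009.html (BEP 9: Extension for Peers to Send Metadata Files)
 */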
package netty.sample
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelInboundHandlerAdapter
import io.netty.buffer.ByteBuf
import io.netty.bootstrap.Bootstrap
import io.netty.channel._
import io.netty.channel.nio._
import io.netty.channel.socket._
import io.netty.channel.socket.nio._
import io.netty.handler.codec._
import java.nio.ByteOrder._
import org.saunter.bencode._
import collection.immutable.HashMap
import java.io._
/** Conversions between byte arrays and hexadecimal strings. */
object Hex {
  /**
   * Decodes consecutive two-character hex slices into bytes, e.g. "0aff"
   * becomes Array(10, -1). Upper- and lower-case digits are both accepted.
   * An odd-length input fails on its trailing slice.
   */
  def fromString(hex: String): Array[Byte] = {
    val pairs = (0 until hex.length by 2).map(start => hex.substring(start, start + 2))
    pairs.map(pair => Integer.parseInt(pair, 16).toByte).toArray
  }

  /**
   * Encodes bytes as lowercase hex with exactly two digits per byte,
   * e.g. Array(10, -1) becomes "0aff".
   */
  def toString(bytes: Array[Byte]): String =
    bytes.map { b =>
      // Setting bit 8 forces a 3-digit result; dropping the leading "1"
      // leaves a zero-padded 2-digit byte.
      Integer.toHexString((b & 0xff) | 0x100).substring(1)
    }.mkString
}
// Info-hash requested by BTClientHandler.channelActive: fc98937bf1447c367b828765d739093b666f35d1
/**
 * The fixed-size (68-byte) BitTorrent peer handshake:
 * - the protocol-string length (19) and "BitTorrent protocol";
 * - 8 reserved bytes;
 * - the 20-byte info-hash;
 * - the 20-byte peer id.
 */
class Handshake {
  // Info-hash of the torrent (20 bytes).
  var hash = new Array[Byte](20)
  // Reserved bytes; 0x10 in byte 5 advertises extension-protocol (BEP 10) support.
  var flag = Array[Byte](0, 0, 0, 0, 0, 16, 0, 0)
  // Peer id (20 bytes); all zeros unless filled in by decode().
  var peer = new Array[Byte](20)
  // Protocol identifier written by encode(); decode() only compares against it.
  val mark = "BitTorrent protocol".toArray.map(_.toByte)

  /** Writes the 68-byte handshake to `buf`. */
  def encode(buf: ByteBuf) = {
    buf.writeByte(19)
    buf.writeBytes(mark)
    buf.writeBytes(flag)
    buf.writeBytes(hash)
    buf.writeBytes(peer)
  }

  /**
   * Reads a handshake from `buf` (needs at least 68 readable bytes).
   * The reserved bytes, info-hash and peer id are stored in this instance.
   *
   * @throws IllegalArgumentException if the protocol string is not "BitTorrent protocol"
   */
  def decode(buf: ByteBuf) = {
    val markSize = buf.readByte()
    // require, not assert: this validates peer input and must not be elidable.
    require(markSize == mark.length, "unexpected protocol string length " + markSize)
    // Read into a scratch array so a misbehaving peer cannot overwrite `mark`.
    val received = new Array[Byte](mark.length)
    buf.readBytes(received)
    require(received.sameElements(mark),
      "unexpected protocol string " + received.map(_.toChar).mkString)
    buf.readBytes(flag)
    println("flag " + Hex.toString(flag))
    buf.readBytes(hash)
    println("hash " + Hex.toString(hash))
    buf.readBytes(peer)
    println("peer " + Hex.toString(peer))
  }

  /** Builds a handshake by decoding it from `buf`. */
  def this(buf: ByteBuf) = {
    this()
    decode(buf)
  }

  /**
   * Builds an outgoing handshake for the given info-hash.
   *
   * @param hashStr the info-hash as 40 hex digits
   * @throws IllegalArgumentException if `hashStr` does not decode to 20 bytes
   */
  def this(hashStr: String) = {
    this()
    hash = Hex.fromString(hashStr)
    require(hash.length == 20, "info-hash must be 20 bytes, got " + hash.length)
  }
}
/**
 * An extension-protocol message (BitTorrent message id 20, BEP 10).
 *
 * With `msgId == 0` it is the extension handshake. A non-zero `msgId` is a
 * ut_metadata (BEP 9) message: a bencoded dictionary, optionally followed by
 * raw metadata bytes.
 *
 * @param mId extended message id to send (0 = extension handshake)
 */
class ExtMessage(mId: Byte) {
  // Id we advertise for ut_metadata in our extension handshake. The peer
  // tags the ut_metadata messages it sends us with this id.
  val myMsgId: Byte = 1
  // BitTorrent message id read from the wire; 20 = extension message.
  var kind: Byte = 20
  // Bencoded dictionary decoded from, or encoded into, the message; null until set.
  var body: Map[String, Any] = null
  var msgId: Byte = mId
  // Metadata piece index sent in a ut_metadata request by encode().
  var piece: Int = 0

  /**
   * Reads one message whose 4-byte length prefix has already been consumed.
   *
   * Messages whose kind is not 20 are skipped. When a ut_metadata message
   * carries piece data, the piece is written to "metainfo.torrent". The file
   * is truncated for piece 0 and appended to otherwise.
   *
   * @param len      message length from the prefix, including the kind byte
   * @param buf      buffer with at least `len` readable bytes
   * @param metaSize total metadata size announced by the peer
   * @param piece    index of the metadata piece this message is expected to carry
   * @throws IllegalArgumentException if the extended message id is neither 0 nor `myMsgId`
   */
  def decode(len: Int, buf: ByteBuf, metaSize: Int, piece: Int): Unit = {
    kind = buf.readByte()
    if (kind != 20) {
      println("skip bytes " + len)
      buf.skipBytes(len - 1)
      println("not support message " + kind)
    } else {
      msgId = buf.readByte()
      // Validate peer input with require: assert can be elided at compile time.
      require(msgId == 0 || msgId == myMsgId, "unexpected extended message id " + msgId)
      val leftSize = len - 1 - 1 // minus the kind byte and the extended-id byte
      val pieceSize = 16 * 1024
      // Expected size of this piece: a full 16 KiB, except for the last piece.
      val thisPiece =
        if ((1 + piece) * pieceSize > metaSize) metaSize - piece * pieceSize else pieceSize
      // Only a payload longer than the expected piece can hold both a dictionary
      // and piece data. A shorter one (e.g. a reject) is treated as a bare
      // dictionary instead of producing a negative array size.
      val hasPieceData = msgId > 0 && thisPiece > 0 && leftSize > thisPiece
      val bencodeSize = if (hasPieceData) leftSize - thisPiece else leftSize
      val bodyByte = new Array[Byte](bencodeSize)
      buf.readBytes(bodyByte)
      // One char per byte, so bencode length prefixes still line up.
      val bodyStr = bodyByte.map(_.toChar).mkString
      println("body: " + bodyStr)
      body = BencodeDecoder.decode(bodyStr).get.asInstanceOf[Map[String, Any]]
      if (hasPieceData) {
        val meta = new Array[Byte](thisPiece)
        buf.readBytes(meta)
        // test: dump the piece to disk; the stream is closed even if write fails.
        val stream = new FileOutputStream("metainfo.torrent", piece != 0)
        try stream.write(meta) finally stream.close()
      }
    }
  }

  /**
   * Writes this message with its 4-byte length prefix.
   * With `msgId == 0` this is our extension handshake. Otherwise it is a
   * ut_metadata request (msg_type 0) for `piece`.
   */
  def encode(buf: ByteBuf) = {
    body =
      if (msgId == 0) Map("m" -> Map("ut_metadata" -> myMsgId.toInt))
      else Map("msg_type" -> 0, "piece" -> piece)
    val str = BencodeEncoder.encode(body)
    println("encoded body " + str)
    val encoded = str.toArray.map(_.toByte)
    val len = 1 + 1 + encoded.size // kind byte + extended-id byte + dictionary
    buf.writeInt(len)
    buf.writeByte(20)
    buf.writeByte(msgId)
    buf.writeBytes(encoded)
  }
}
/**
 * Drives the metadata exchange from the client side:
 * - On connect, it sends our handshake and our extension handshake.
 * - When the peer's extension handshake arrives, it requests metadata piece 0.
 */
class BTClientHandler extends ChannelInboundHandlerAdapter {
  // ut_metadata message id advertised by the peer.
  var msgId: Byte = 0

  override def channelActive(ctx: ChannelHandlerContext): Unit = {
    println("channel active")
    ctx.write(new Handshake("fc98937bf1447c367b828765d739093b666f35d1"))
    ctx.writeAndFlush(new ExtMessage(0))
  }

  override def channelRead(ctx: ChannelHandlerContext, msg: Object): Unit = {
    msg match {
      case _: Handshake => println("decode a handshake")
      case ext: ExtMessage =>
        println("decode an ext-message: " + ext.kind)
        if (ext.kind == 20) processExtMsg(ctx, ext)
      case other =>
        // Not a message this handler understands: pass it down the pipeline
        // instead of failing with a MatchError.
        ctx.fireChannelRead(other)
    }
  }

  override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
    cause.printStackTrace()
    ctx.close()
  }

  /**
   * Handles an extension message. On the peer's extension handshake
   * (msgId 0), it requests metadata piece 0. If the peer does not advertise
   * ut_metadata, it closes the connection.
   */
  private def processExtMsg(ctx: ChannelHandlerContext, msg: ExtMessage): Unit = {
    assert(msg.kind == 20)
    if (msg.msgId == 0) {
      advertisedMetadataId(msg.body) match {
        case Some(id) =>
          msgId = id
          val request = new ExtMessage(msgId)
          request.piece = 0
          println("start to request metadata")
          ctx.writeAndFlush(request)
        case None =>
          println("peer does not support ut_metadata")
          ctx.close()
      }
    } else {
      println("msg.msgId " + msg.msgId)
    }
  }

  /**
   * Extracts the peer's ut_metadata id from the "m" dictionary of an extension
   * handshake body. Returns None if the id is absent or 0 (0 means the
   * extension is disabled).
   */
  private def advertisedMetadataId(body: Map[String, Any]): Option[Byte] =
    Option(body).flatMap(_.get("m")) match {
      case Some(m: Map[String, Any] @unchecked) =>
        m.get("ut_metadata") match {
          case Some(n: java.lang.Number) if n.longValue() > 0 => Some(n.longValue().toByte)
          case _ => None
        }
      case _ => None
    }
}
/**
 * Splits the inbound byte stream into messages:
 * - first the 68-byte handshake;
 * - then 4-byte big-endian length-prefixed messages, decoded as `ExtMessage`s.
 *
 * It also tracks the ut_metadata download and requests the next metadata
 * piece each time one arrives.
 */
class ProtocolDecoder extends ByteToMessageDecoder {
  // Number of messages decoded so far; 0 until the handshake has arrived.
  var seq = 0
  // "metadata_size" from the peer's extension handshake.
  var totalSize: Int = 0
  // Index of the next metadata piece to receive.
  var piece: Int = 0
  // Peer's ut_metadata extended message id.
  var msgId: Byte = 0
  // Metadata piece size used by ut_metadata (16 KiB, BEP 9).
  private val PieceSize = 16 * 1024

  override def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: java.util.List[Object]): Unit = {
    val readable = in.readableBytes
    if (seq == 0) {
      if (readable >= 68) {
        println("recv handshake")
        out.add(new Handshake(in))
        seq += 1
      }
    } else if (readable >= 4) {
      // Peek at the length prefix without consuming it. If the frame is
      // incomplete, nothing is read and the decoder waits for more data
      // (no resetReaderIndex needed).
      val len = in.getInt(in.readerIndex)
      require(len >= 0, "negative message length " + len)
      if (len == 0) {
        in.skipBytes(4) // keep-alive message: length prefix only, no payload
      } else if (readable - 4 >= len) {
        in.skipBytes(4)
        val ext = new ExtMessage(0)
        ext.decode(len, in, totalSize, piece)
        if (ext.kind == 20 && ext.msgId > 0) {
          piece += 1
          // Request the next piece only while metadata remains beyond it.
          if (piece * PieceSize < totalSize) {
            val request = new ExtMessage(msgId)
            request.piece = piece
            println("continue to request piece " + piece)
            ctx.writeAndFlush(request)
          }
        }
        readTotalSize(ext)
        out.add(ext)
        seq += 1
      }
    }
  }

  /**
   * From the peer's extension handshake (kind 20, msgId 0), records its
   * ut_metadata id and the metadata size.
   */
  private def readTotalSize(ext: ExtMessage) = {
    if (ext.kind == 20 && ext.msgId == 0) {
      val m = ext.body.get("m").get.asInstanceOf[Map[String, Any]]
      msgId = m.get("ut_metadata").get.asInstanceOf[Long].toByte
      totalSize = ext.body.get("metadata_size").get.asInstanceOf[Long].toInt
      println("totalSize " + totalSize)
      println("msgId " + msgId)
    }
  }
}
/** Outbound encoder that serializes a `Handshake` into the channel's byte buffer. */
class HandshakeEncoder extends MessageToByteEncoder[Handshake] {
  override def encode(ctx: ChannelHandlerContext, msg: Handshake, out: ByteBuf): Unit = {
    println("encode handshake message")
    // Serialization lives on the message itself.
    msg.encode(out)
  }
}
/** Outbound encoder that serializes an `ExtMessage` (length prefix included). */
class ExtMsgEncoder extends MessageToByteEncoder[ExtMessage] {
  override def encode(ctx: ChannelHandlerContext, msg: ExtMessage, out: ByteBuf): Unit = {
    println("encode ext message")
    // Serialization lives on the message itself.
    msg.encode(out)
  }
}
/**
 * Entry point. Connects to a peer and runs the metadata exchange until the
 * connection closes.
 *
 * Usage: BTClient [host] [port]; defaults to localhost:6776.
 */
object BTClient {
  private val DefaultHost = "localhost"
  private val DefaultPort = 6776

  def main(args: Array[String]): Unit = {
    val host = if (args.length > 0) args(0) else DefaultHost
    val port = if (args.length > 1) args(1).toInt else DefaultPort
    val workerGroup = new NioEventLoopGroup
    try {
      val boot = new Bootstrap()
      boot.group(workerGroup)
        .channel(classOf[NioSocketChannel])
        .handler(new ChannelInitializer[SocketChannel]() {
          def initChannel(ch: SocketChannel): Unit = {
            ch.pipeline().addLast(new HandshakeEncoder(), new ExtMsgEncoder(),
              new ProtocolDecoder(),
              new BTClientHandler())
          }
        })
        .option[java.lang.Boolean](ChannelOption.SO_KEEPALIVE, true)
      println("connect to server")
      val f: ChannelFuture = boot.connect(host, port)
      f.addListener(new ChannelFutureListener() {
        override def operationComplete(f: ChannelFuture): Unit =
          if (f.isSuccess()) println("connect success")
          else println("connect failed " + stackTraceOf(f.cause()))
      })
      // Block until the connection closes (immediately if the connect failed).
      f.channel().closeFuture().sync()
      println("client exit")
    } finally {
      workerGroup.shutdownGracefully()
    }
  }

  /**
   * Renders a throwable and its stack trace as a string. This replaces the
   * deprecated `getStackTraceString` enrichment.
   */
  private def stackTraceOf(t: Throwable): String = {
    val sw = new StringWriter
    t.printStackTrace(new PrintWriter(sw))
    sw.toString
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment