Skip to content

Instantly share code, notes, and snippets.

@kevinlynx
Created October 25, 2013 14:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kevinlynx/7155551 to your computer and use it in GitHub Desktop.
Save kevinlynx/7155551 to your computer and use it in GitHub Desktop.
Second version of a client that requests torrent metainfo from a peer using the BEP 9 (ut_metadata) and BEP 10 (extension protocol) specifications.
/*
* http://www.bittorrent.org/beps/bep_0010.html
* http://www.bittorrent.org/beps/bep_0009.html
*/
package bep.impl
import java.io._
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelInboundHandlerAdapter
import io.netty.buffer._
import io.netty.bootstrap.Bootstrap
import io.netty.channel._
import io.netty.channel.nio._
import io.netty.channel.socket._
import io.netty.channel.socket.nio._
import io.netty.handler.codec._
import java.nio.ByteOrder._
import org.saunter.bencode._
import collection.immutable.HashMap
/** Conversions between byte arrays and lowercase hexadecimal strings. */
object Hex {
  /** Parses a hex string (two digits per byte, either case) into bytes.
   *
   *  @param hex an even-length string of hexadecimal digits
   *  @return the decoded bytes; empty for an empty string
   *  @throws IllegalArgumentException if `hex` has an odd number of characters
   *  @throws NumberFormatException if a pair is not valid hexadecimal
   */
  def fromString(hex: String): Array[Byte] = {
    // Reject odd lengths up front instead of failing later with an index error.
    require(hex.length % 2 == 0, s"hex string must have an even number of digits: '$hex'")
    hex.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
  }

  /** Renders `bytes` as lowercase hex, two digits per byte (zero-padded). */
  def toString(bytes: Array[Byte]): String =
    bytes.map(b => "%02x".format(b & 0xff)).mkString
}
/** The BitTorrent peer-wire handshake: protocol string, 8 reserved bytes, info hash, peer id.
 *
 *  @param hash 20-byte info hash
 *  @param peer 20-byte peer id
 */
class Handshake(val hash: Array[Byte], val peer: Array[Byte]) {
  // Reserved bytes. Byte 5 = 0x10 advertises support for the BEP 10 extension protocol.
  // For a handshake decoded by `Handshake.apply` this holds the bytes the peer actually sent.
  var flag = Array[Byte](0, 0, 0, 0, 0, 16, 0, 0)
  require(hash.length == 20, s"info hash must be 20 bytes, got ${hash.length}")
  require(peer.length == 20, s"peer id must be 20 bytes, got ${peer.length}")

  /** Outgoing handshake for a hex-encoded info hash, with an all-zero peer id. */
  def this(hashStr: String) = this(Hex.fromString(hashStr), new Array[Byte](20))

  /** Writes the handshake (length byte, protocol string, flag, hash, peer) to `out`. */
  def encode(out: DataOutput): Unit = {
    out.writeByte(Handshake.mark.length)
    out.write(Handshake.mark)
    out.write(flag)
    out.write(hash)
    out.write(peer)
  }
}

object Handshake {
  /** Protocol identifier that follows the leading length byte. */
  val mark: Array[Byte] = "BitTorrent protocol".getBytes("US-ASCII")
  /** Encoded size: length byte + protocol string + reserved + info hash + peer id. */
  val length = 1 + mark.length + 8 + 20 + 20

  /** True when `bytes` readable bytes are enough to hold a complete handshake. */
  def enough(bytes: Int) = bytes >= length

  /** Decodes a handshake from `in`.
   *
   *  @throws IOException if the protocol string is not "BitTorrent protocol"
   */
  def apply(in: DataInput): Handshake = {
    println("decode handshake")
    val markSize = in.readByte()
    if (markSize != mark.length)
      throw new IOException(s"unexpected protocol string length: $markSize")
    // readFully instead of skipBytes: DataInput.skipBytes may skip fewer bytes than asked.
    val protocol = new Array[Byte](markSize)
    in.readFully(protocol)
    if (!(protocol sameElements mark))
      throw new IOException("not a BitTorrent handshake")
    val reserved = new Array[Byte](8)
    in.readFully(reserved)
    val hash = new Array[Byte](20)
    in.readFully(hash)
    val peer = new Array[Byte](20)
    in.readFully(peer)
    val handshake = new Handshake(hash, peer)
    // Keep the peer's reserved bytes so callers can inspect its extension support.
    handshake.flag = reserved
    handshake
  }
}
/** A BEP 10 extended message: peer-wire id 20, an extended message id, and a bencoded
 *  dictionary payload.
 *
 *  @param msgId extended message id (0 is the extension handshake)
 *  @param body  dictionary to bencode as the payload
 */
class ExtMessage(val msgId: Byte, val body: Map[String, Any]) {
  /** Writes the length-prefixed message `<len:int32><20><msgId><bencoded body>` to `out`. */
  def encode(out: DataOutput): Unit = {
    val str = BencodeEncoder.encode(body)
    // one byte per char: the bencoded text is treated as a byte string
    val encoded = str.toArray.map(_.toByte)
    out.writeInt(1 + 1 + encoded.length)
    out.writeByte(20)
    out.writeByte(msgId)
    out.write(encoded)
  }
}

object ExtMessage {
  // Id we advertise for ut_metadata; the peer tags the metadata it sends us with it.
  val MY_MSG_ID: Byte = 1
  // ut_metadata piece size (16 KiB)
  val SIZE_PER_PIECE = 16*1024

  /** Extension handshake advertising ut_metadata under MY_MSG_ID. */
  def apply() = new ExtMessage(0, Map("m" -> Map("ut_metadata" -> MY_MSG_ID.toInt)))

  /** ut_metadata request (msg_type 0) for `piece`, addressed with the peer's id `msgId`. */
  def apply(msgId: Byte, piece: Int) = new ExtMessage(msgId, Map("msg_type" -> 0, "piece" -> piece))

  /** Reads one message body of `size` bytes (its length prefix already consumed).
   *
   *  @return ('exthandshake, n) or ('extend, n) with `n` payload bytes still unread,
   *          or ('unknown, id) after consuming the whole message. A zero-length
   *          (keep-alive) message reads nothing and yields ('unknown, -1).
   */
  def decode(size: Int, in: DataInput): (Symbol, Int) =
    if (size <= 0) ('unknown, -1) // keep-alive: no id byte to read
    else in.readByte match {
      case 20 =>
        val msgId = in.readByte()
        do_decode(size - 2, msgId, in)
      case kind =>
        skipFully(in, size - 1)
        ('unknown, kind.toInt)
    }

  /** Parses the peer's extension handshake payload.
   *
   *  @return (peer's ut_metadata message id, metadata_size)
   *  @throws IOException if either field is missing or ut_metadata is disabled (id 0)
   */
  def decodeHandshake(size: Int, in: DataInput): (Byte, Int) = {
    val body = new Array[Byte](size)
    in.readFully(body)
    // Map each byte to the char with the same unsigned value; a plain toChar would
    // sign-extend bytes >= 0x80 into '\uFF80'..'\uFFFF'.
    val text = body.map(b => (b & 0xff).toChar).mkString
    val tbl = BencodeDecoder.decode(text).get.asInstanceOf[Map[String, Any]]
    def missing(key: String): Nothing =
      throw new IOException(s"extension handshake lacks '$key'")
    val m = tbl.getOrElse("m", missing("m")).asInstanceOf[Map[String, Any]]
    val msgId = m.getOrElse("ut_metadata", missing("ut_metadata")).asInstanceOf[Long].toByte
    if (msgId == 0) throw new IOException("peer has ut_metadata disabled")
    val metaSize = tbl.getOrElse("metadata_size", missing("metadata_size")).asInstanceOf[Long].toInt
    (msgId, metaSize)
  }

  /** Reads a ut_metadata data message: a bencoded dictionary followed by the raw bytes
   *  of `piece`. Returns only the raw bytes.
   *
   *  The piece length is derived from `metaSize`: SIZE_PER_PIECE for every piece but the
   *  last. The leading dictionary is skipped without checking msg_type, so a message
   *  that is too short to carry the piece (e.g. a reject) fails the size check.
   *
   *  @throws IOException if `size` cannot hold a dictionary plus the expected piece
   */
  def decodeMetaData(size: Int, metaSize: Int, piece: Int, in: DataInput): Array[Byte] = {
    val pieceSize = math.min(SIZE_PER_PIECE, metaSize - piece * SIZE_PER_PIECE)
    val dictSize = size - pieceSize
    if (pieceSize <= 0 || dictSize <= 0)
      throw new IOException(
        s"unexpected ut_metadata message for piece $piece (size $size, metadata_size $metaSize)")
    skipFully(in, dictSize)
    val data = new Array[Byte](pieceSize)
    in.readFully(data)
    data
  }

  // Classifies an extended message by id. Ids we never advertised are consumed and
  // reported as unknown peer-wire message 20 instead of throwing MatchError.
  private def do_decode(size: Int, msgId: Byte, in: DataInput): (Symbol, Int) = msgId match {
    case 0 => ('exthandshake, size)
    case MY_MSG_ID => ('extend, size)
    case _ =>
      skipFully(in, size)
      ('unknown, 20)
  }

  // DataInput.skipBytes may skip fewer bytes than requested; loop until all are gone.
  private def skipFully(in: DataInput, n: Int): Unit = {
    var left = n
    while (left > 0) {
      val skipped = in.skipBytes(left)
      if (skipped <= 0) throw new EOFException(s"$left bytes left to skip")
      left -= skipped
    }
  }
}
/** Per-connection handler that downloads a torrent's metainfo from one peer.
 *
 *  On connect it sends the BitTorrent handshake and a BEP 10 extension handshake. It
 *  then requests ut_metadata pieces one at a time until `metaSize` bytes have arrived.
 *  Finally it checks their SHA-1 against the info hash and writes them to `outputPath`.
 *  Each instance holds the state of the single connection it is attached to.
 *
 *  @param infoHash   hex-encoded 20-byte info hash of the wanted torrent
 *  @param outputPath file that the verified metadata is written to
 */
class ClientHandler(infoHash: String = "fc98937bf1447c367b828765d739093b666f35d1",
                    outputPath: String = "metainfo.torrent")
    extends ChannelInboundHandlerAdapter {
  // ut_metadata message id the peer announced in its extension handshake
  var msgId: Byte = 0
  // total metadata size announced by the peer
  var metaSize: Int = 0
  // index of the next piece expected
  var piece: Int = 0
  // metadata bytes received so far, appended in request order
  var metaInfo: Array[Byte] = Array()

  override def channelActive(ctx: ChannelHandlerContext): Unit = {
    ctx.write(new Handshake(infoHash))
    ctx.writeAndFlush(ExtMessage())
  }

  override def channelRead(ctx: ChannelHandlerContext, msg: Object): Unit = msg match {
    case ('exthandshake, id: Byte, size: Int) =>
      this.msgId = id
      this.metaSize = size
      println("request first piece")
      ctx.writeAndFlush(ExtMessage(this.msgId, 0))
    case ('metadata, data: Array[Byte]) =>
      metaInfo = metaInfo ++ data
      piece += 1
      if (piece * ExtMessage.SIZE_PER_PIECE >= metaSize) {
        dumpMeta()
        ctx.close()
      } else {
        println("request piece " + piece)
        ctx.writeAndFlush(ExtMessage(this.msgId, piece))
      }
    case h: Handshake => println("recv handshake")
    case ('unknown, kind) => println("recv unknown msg: " + kind)
    // Anything else is not ours: pass it down the pipeline instead of throwing MatchError.
    case other => ctx.fireChannelRead(other)
  }

  override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
    cause.printStackTrace()
    ctx.close()
  }

  // Writes the metadata to `outputPath` only if its SHA-1 equals the info hash, as
  // BEP 9 requires. Otherwise it reports the mismatch and writes nothing.
  private def dumpMeta(): Unit = {
    println("download metainfo done")
    val digest = Hex.toString(java.security.MessageDigest.getInstance("SHA-1").digest(metaInfo))
    if (digest != infoHash.toLowerCase) {
      println(s"metainfo hash mismatch: expected $infoHash, got $digest")
    } else {
      val stream = new FileOutputStream(outputPath)
      try stream.write(metaInfo) finally stream.close()
    }
  }
}
/** Splits the inbound byte stream into peer-wire messages.
 *
 *  The peer's first transmission is the fixed-size handshake. After that, every message
 *  is framed as a 4-byte big-endian length followed by that many bytes. Output objects
 *  are a `Handshake` or the symbol-tagged tuples built by `decodeFixed`.
 */
class MessageDecoder extends ByteToMessageDecoder {
  // metadata_size announced in the peer's extension handshake
  var metaSize = 0
  // index of the next metadata piece expected
  var piece = 0
  // true until the peer's handshake has been decoded
  var first = true

  override def decode(ctx: ChannelHandlerContext, buf: ByteBuf, out: java.util.List[Object]): Unit = {
    if (first) {
      // The handshake may arrive split across reads. Consume nothing until it is all here.
      if (Handshake.enough(buf.readableBytes())) {
        out.add(Handshake(new ByteBufInputStream(buf)))
        first = false
      }
    } else if (buf.readableBytes() >= 4) {
      val pos = buf.readerIndex()
      val size = buf.readInt()
      if (size < 0) {
        throw new CorruptedFrameException(s"negative message length: $size")
      } else if (size == 0) {
        // keep-alive: a bare length prefix; consume it and emit nothing
      } else if (buf.readableBytes() >= size) {
        // Decode from a slice of exactly `size` bytes. A mis-sized payload then cannot
        // spill into, or leave residue before, the next message.
        out.add(decodeFixed(size, new ByteBufInputStream(buf.readSlice(size))))
      } else {
        // incomplete frame: rewind and wait for more data
        buf.readerIndex(pos)
      }
    }
  }

  // Decodes one framed message body into a tagged tuple, tracking extension state.
  private def decodeFixed(size: Int, in: DataInput) =
    ExtMessage.decode(size, in) match {
      case ('unknown, kind) => ('unknown, kind)
      case ('exthandshake, remain) =>
        val (msgId, metaSize) = ExtMessage.decodeHandshake(remain, in)
        this.metaSize = metaSize
        ('exthandshake, msgId, metaSize)
      case ('extend, remain) =>
        val data = ExtMessage.decodeMetaData(remain, metaSize, piece, in)
        piece += 1
        ('metadata, data)
      case other => other
    }
}
/** Netty outbound encoder that serialises `Handshake` messages into the channel buffer. */
class HandshakeEncoder extends MessageToByteEncoder[Handshake] {
  override def encode(ctx: ChannelHandlerContext, msg: Handshake, out: ByteBuf): Unit = {
    val sink = new ByteBufOutputStream(out)
    msg.encode(sink)
  }
}
/** Netty outbound encoder that serialises `ExtMessage` messages into the channel buffer. */
class ExtMsgEncoder extends MessageToByteEncoder[ExtMessage] {
  override def encode(ctx: ChannelHandlerContext, msg: ExtMessage, out: ByteBuf): Unit = {
    val sink = new ByteBufOutputStream(out)
    msg.encode(sink)
  }
}
/** Command-line entry point: connects to one peer and downloads metainfo.
 *
 *  Usage: `Client [host] [port]` (defaults: localhost 6776).
 */
object Client {
  def main(args: Array[String]): Unit = {
    val host = args.lift(0).getOrElse("localhost")
    val port = args.lift(1).map(_.toInt).getOrElse(6776)
    val workerGroup = new NioEventLoopGroup
    try {
      val boot = new Bootstrap()
      boot.group(workerGroup)
        .channel(classOf[NioSocketChannel])
        .handler(new ChannelInitializer[SocketChannel]() {
          def initChannel(ch: SocketChannel): Unit = {
            ch.pipeline().addLast(new HandshakeEncoder(), new ExtMsgEncoder(),
              new MessageDecoder(),
              new ClientHandler())
          }
        })
        .option[java.lang.Boolean](ChannelOption.SO_KEEPALIVE, true)
      val f: ChannelFuture = boot.connect(host, port)
      f.addListener(new ChannelFutureListener() {
        override def operationComplete(f: ChannelFuture): Unit =
          if (f.isSuccess()) println("connect success")
          else println("connect failed " + stackTraceOf(f.cause()))
      })
      // A failed connect also closes the channel, so this returns in both cases.
      f.channel().closeFuture().sync()
      println("client exit")
    } finally {
      workerGroup.shutdownGracefully()
    }
  }

  // Full printStackTrace text, including the exception message. It replaces
  // Throwable.getStackTraceString, which is deprecated and removed in Scala 2.13.
  private def stackTraceOf(t: Throwable): String = {
    val writer = new StringWriter
    t.printStackTrace(new PrintWriter(writer))
    writer.toString
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment